Spaces:
Sleeping
Sleeping
| # app.py | |
| import os | |
| import json | |
| from datetime import datetime | |
| import re | |
| import pandas as pd | |
| import streamlit as st | |
| from huggingface_hub import InferenceClient | |
| # ------------------ Setup PublicAI Inference Client ------------------ | |
def get_publicai_client():
    """Build an ``InferenceClient`` authenticated with the PublicAI token.

    The token is read from the environment. Both ``publicai`` (the name this
    function has always looked up) and ``PUBLICAI`` (the name the error
    message tells users to set) are accepted, because environment variable
    names are case-sensitive on POSIX systems and the two previously
    disagreed.

    Returns:
        An ``InferenceClient`` constructed with ``api_key`` set to the token.

    Raises:
        ValueError: If neither variable holds a non-empty value.
    """
    # Keep the historical lowercase name first so existing deployments
    # resolve exactly the same token as before.
    token = os.getenv("publicai") or os.getenv("PUBLICAI")
    if not token:
        raise ValueError("Missing PublicAI token. Set 'PUBLICAI' env variable in the space secrets.")
    return InferenceClient(api_key=token)
| def _call_llama_space(prompt: str) -> str: | |
| try: | |
| client = get_publicai_client() | |
| response = client.chat.completions.create( | |
| model="swiss-ai/apertus-8b-instruct", | |
| messages=[{"role": "user", "content": prompt}], | |
| ) | |
| return response.choices[0].message["content"].strip() | |
| except Exception as e: | |
| return f"Error calling PublicAI model: {e}" | |
| # ------------------ Utils ------------------ | |
# Prefer the package-relative import; fall back to a plain import when the
# relative form raises ImportError (e.g. the file is run as a script with no
# parent package). Only ImportError is caught so that genuine errors raised
# while importing ``utils`` are not masked by the fallback.
try:
    from .utils import (
        generate_synthetic_transactions,
        filter_transactions,
        compute_aggregations,
        build_time_series_chart,
        build_category_bar_chart,
        build_payment_method_pie_chart,
    )
except ImportError:
    from utils import (
        generate_synthetic_transactions,
        filter_transactions,
        compute_aggregations,
        build_time_series_chart,
        build_category_bar_chart,
        build_payment_method_pie_chart,
    )
| # ------------------ Streamlit Setup ------------------ | |
# Configure the page title, page icon, and wide layout for the app.
st.set_page_config(page_title="AI Spending Analyser", page_icon="💳", layout="wide")
| # ------------------ Helpers for Markdown/Spikes ------------------ | |
def strip_markdown(text: str) -> str:
    """Remove inline Markdown markers from ``text``, keeping the inner text.

    Handles, in order: ``**bold**`` / ``__bold__``, ``*italic*`` /
    ``_italic_``, `` `code` `` and ``~~strikethrough~~``. Values that are not
    strings are returned unchanged.
    """
    if not isinstance(text, str):
        return text

    # Order matters: double markers must be consumed before single ones.
    rules = (
        (r"(\*\*|__)(.*?)\1", r"\2"),
        (r"(\*|_)(.*?)\1", r"\2"),
        (r"`(.*?)`", r"\1"),
        (r"~~(.*?)~~", r"\1"),
    )
    cleaned = text
    for pattern, replacement in rules:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned
def markdown_bold_to_html(text: str) -> str:
    """Convert every ``**bold**`` span in ``text`` to ``<strong>bold</strong>``."""
    bold_span = re.compile(r"\*\*(.*?)\*\*")
    return bold_span.sub(r"<strong>\1</strong>", text)
def sanitize_llm_text_for_users(text: str) -> str:
    """Replace internal identifiers and raw booleans with friendlier wording.

    ``IsSpike`` (any letter case) becomes ``spike indicator``; the exact
    words ``True`` and ``False`` become ``yes`` and ``no``. Values that are
    not strings are returned unchanged.
    """
    if not isinstance(text, str):
        return text
    without_column_name = re.sub(r"\bIsSpike\b", "spike indicator", text, flags=re.IGNORECASE)
    without_true = re.sub(r"\bTrue\b", "yes", without_column_name)
    return re.sub(r"\bFalse\b", "no", without_true)
| # ------------------ Session State ------------------ | |
def init_session_state():
    """Seed ``st.session_state`` with the dataset and a complete filters dict.

    ``data`` is generated once (900 rows, seed 42). ``filters`` is created
    with the full date span of that data, no category or merchant
    restriction, and the same defaults the sidebar widgets in
    ``render_sidebar`` start with. Including those keys means ``main`` can
    read ``summary_mode``, ``engine``, ``show_spikes`` and
    ``large_tx_threshold`` without a ``KeyError``, even if the sidebar has
    not yet overwritten the dict.
    """
    if "data" not in st.session_state:
        st.session_state.data = generate_synthetic_transactions(n_rows=900, seed=42)
    if "filters" not in st.session_state:
        dates = st.session_state.data["Date"]
        st.session_state.filters = {
            "date_range": (dates.min(), dates.max()),
            "categories": [],
            "merchant_query": "",
            # Defaults mirroring the sidebar widgets' initial values.
            "summary_mode": "Concise",
            "engine": "swiss-ai/apertus-8b-instruct",
            "show_spikes": True,
            "large_tx_threshold": 250,
        }
| # ------------------ Header ------------------ | |
def render_header():
    """Render the page banner: a large caret glyph beside the app title."""
    banner_html = """
    <div style='display: flex; align-items: baseline; gap: 15px; margin-bottom: 20px;'>
    <div style='font-size: 80px; color: #00AEEF; font-weight: bold; line-height: 1;'>^</div>
    <div style='font-size: 36px; color: #697089; font-weight: 500; line-height: 1;'>AI Spending Analyser</div>
    </div>
    """
    st.markdown(banner_html, unsafe_allow_html=True)
def render_assistant_banner():
    """No-op: returns immediately without rendering anything."""
    return None
def render_chat_fab():
    """No-op: returns immediately without rendering anything."""
    return None
| # ------------------ Sidebar ------------------ | |
def render_sidebar(df: pd.DataFrame):
    """Draw the sidebar controls and store the selections in session state.

    The sidebar has three sections: "Filters" (date range, category
    multiselect, merchant search), "AI" (summary mode and engine) and
    "Anomalies & Highlights" (spike toggle, large-transaction threshold and
    a "Regenerate" button that replaces the dataset). After the widgets are
    drawn, ``st.session_state.filters`` is overwritten with the current
    values.

    Args:
        df: Transactions frame; its ``Date`` column bounds the date pickers
            and its ``Category`` column supplies the category options.
    """
    sidebar = st.sidebar

    sidebar.header("Filters")
    earliest = df["Date"].min().date()
    latest = df["Date"].max().date()
    left, right = sidebar.columns(2)
    with left:
        from_date = st.date_input("From", value=earliest, min_value=earliest, max_value=latest)
    with right:
        to_date = st.date_input("To", value=latest, min_value=earliest, max_value=latest)
    if from_date > to_date:
        # An inverted range is reported and replaced by the full span.
        sidebar.error("From date cannot be after To date")
        from_date, to_date = earliest, latest
    category_options = sorted(df["Category"].unique())
    categories = sidebar.multiselect("Category", options=category_options, default=[])
    merchant_query = sidebar.text_input("Merchant search", value="", placeholder="Type a merchant name…")

    sidebar.divider()
    sidebar.header("AI")
    summary_mode = sidebar.radio("Summary", options=["Concise", "Detailed"], index=0, horizontal=True)
    engine = sidebar.selectbox("Engine", options=["swiss-ai/apertus-8b-instruct"], index=0)

    sidebar.divider()
    sidebar.header("Anomalies & Highlights")
    show_spikes = sidebar.toggle("Show spike markers", value=True)
    large_tx_threshold = sidebar.slider("Large transaction threshold (£)", 50, 1000, 250, step=25)
    if sidebar.button("Regenerate"):
        st.session_state.data = generate_synthetic_transactions(n_rows=900)

    # Expand the picked dates to cover whole days: start of the first day
    # through the last representable time of the final day.
    range_start = datetime.combine(from_date, datetime.min.time())
    range_end = datetime.combine(to_date, datetime.max.time())
    st.session_state.filters = {
        "date_range": (range_start, range_end),
        "categories": categories,
        "merchant_query": merchant_query.strip(),
        "summary_mode": summary_mode,
        "engine": engine,
        "show_spikes": show_spikes,
        "large_tx_threshold": large_tx_threshold,
    }
| # ------------------ Metrics ------------------ | |
def render_metrics(agg: dict):
    """Render four KPI cards in a row: total, monthly average, max and min.

    Args:
        agg: Aggregations dict; reads ``total_spend``, ``avg_monthly_spend``
            and the ``Amount`` entry of ``max_transaction`` and
            ``min_transaction``.
    """
    columns = st.columns(4)
    cards = (
        ("Total Value", agg["total_spend"]),
        ("Avg Monthly", agg["avg_monthly_spend"]),
        ("Max Transaction", agg["max_transaction"]["Amount"]),
        ("Min Transaction", agg["min_transaction"]["Amount"]),
    )
    for column, (label, value) in zip(columns, cards):
        # Split pounds from pence so they can be shown at different sizes.
        pounds, pence = f"{value:,.2f}".split(".")
        card_html = f"""
        <div class='metric-card' title='{label}: £{value:,.2f}'>
        <div class='metric-label'>{label}</div>
        <div class='kpi-value'><span style='font-size:0.9em;'>£</span><span style='font-size:2.3em;font-weight:bold;'>{pounds}</span><span style='font-size:1em;'>.{pence}</span></div>
        </div>
        """
        column.markdown(card_html, unsafe_allow_html=True)
| # ------------------ ISA Widget ------------------ | |
def render_isa_widget(current_spend: float, allowance: float):
    """Render the allowance progress bar with USED and REMAINING amounts.

    Args:
        current_spend: Amount spent so far; capped at ``allowance`` for display.
        allowance: Total allowance; a value of zero or less yields 0% progress.
    """
    used = min(current_spend, allowance)
    remaining = max(allowance - used, 0)
    if allowance <= 0:
        percent = 0
    else:
        percent = int((used / allowance) * 100)

    st.markdown("<div class='isa-widget'>", unsafe_allow_html=True)
    st.subheader("ISA allowance")
    st.markdown(f"<div class='progress'><div style='width:{percent}%;'></div></div>", unsafe_allow_html=True)

    used_column, remaining_column = st.columns(2)
    with used_column:
        used_whole, used_pence = f"{used:,.2f}".split(".")
        used_html = f"<div><span style='font-weight:600;'>USED</span><br/><span style='font-size:2em;font-weight:bold;'>£{used_whole}</span><span style='font-size:1em;'>.{used_pence}</span></div>"
        st.markdown(used_html, unsafe_allow_html=True)
    with remaining_column:
        rem_whole, rem_pence = f"{remaining:,.2f}".split(".")
        remaining_html = f"<div><span style='font-weight:600;color:rgba(255,255,255,0.8)'>REMAINING</span><br/><span style='font-size:2em;font-weight:bold;'>£{rem_whole}</span><span style='font-size:1em;'>.{rem_pence}</span></div>"
        st.markdown(remaining_html, unsafe_allow_html=True)

    st.markdown("</div>", unsafe_allow_html=True)
| # ------------------ Charts ------------------ | |
def render_charts(filtered_df: pd.DataFrame, agg: dict, template: str, show_spikes: bool):
    """Render the three chart tabs: Trend, By Category and Payment Methods.

    Args:
        filtered_df: Transactions passed to the time-series chart builder.
        agg: Aggregations dict; reads ``spikes``, ``spend_per_category`` and
            ``spend_per_payment``.
        template: Template name forwarded to every chart builder.
        show_spikes: When true, ``agg["spikes"]`` is passed as the trend
            chart's spike overlay; otherwise ``None`` is passed.
    """
    trend_tab, category_tab, payment_tab = st.tabs(["Trend", "By Category", "Payment Methods"])
    chart_style = "chart-card"

    def show_in_card(figure):
        # Emit the opening card markup, the chart, then the closing markup.
        st.markdown(f"<div class='{chart_style}'>", unsafe_allow_html=True)
        st.plotly_chart(figure, use_container_width=True, responsive=True)
        st.markdown("</div>", unsafe_allow_html=True)

    with trend_tab:
        overlay = agg["spikes"] if show_spikes else None
        trend_fig = build_time_series_chart(filtered_df, template=template, spike_overlay=overlay)
        show_in_card(trend_fig)

    with category_tab:
        st.caption("Tip: Select categories in the sidebar to compare their total spend.")
        category_fig = build_category_bar_chart(agg["spend_per_category"], template=template)
        show_in_card(category_fig)

    with payment_tab:
        payment_fig = build_payment_method_pie_chart(agg["spend_per_payment"], template=template)
        show_in_card(payment_fig)
| # ------------------ Heuristic AI Summary ------------------ | |
def heuristic_summary(agg: dict, mode: str) -> str:
    """Compose a plain-text spending summary without calling a model.

    The text states the number of days in the active date range (read from
    ``st.session_state.filters``), the total and average monthly spend, a
    sentence about the largest spike (if any), and up to eight categories
    ranked by spend with their share of the total.

    Args:
        agg: Aggregations dict; reads ``total_spend``, ``avg_monthly_spend``,
            ``spend_per_category`` and ``spikes``.
        mode: Accepted for interface compatibility; not used by this function.

    Returns:
        The summary text.
    """
    total = agg.get("total_spend", 0)
    avg_month = agg.get("avg_monthly_spend", 0)
    by_category = agg.get("spend_per_category", {})
    spikes = agg.get("spikes", pd.DataFrame())

    period = st.session_state.filters.get("date_range", (None, None))
    if all(period):
        num_days = (period[1] - period[0]).days + 1
    else:
        num_days = "unknown"

    if not spikes.empty and len(spikes) != 0:
        try:
            peak = spikes.loc[spikes["Amount"].idxmax()]
            peak_date = pd.to_datetime(peak["Date"]).strftime("%Y-%m-%d")
            peak_amount = peak["Amount"]
            spike_text = f"A noticeable spending spike occurred on {peak_date}, reaching £{peak_amount:,.2f}, which exceeded the 28-day rolling average."
        except Exception:
            # Fall back to a generic sentence if the spike row cannot be read.
            spike_text = "Some days showed higher-than-normal spending compared to the 28-day rolling average."
    else:
        spike_text = "No unusual spending spikes were detected during this period."

    pieces = [
        f"Over a period of {num_days} days, the total spend was £{total:,.2f} with an average monthly spend of £{avg_month:,.2f}. {spike_text} Here’s a breakdown of the spending patterns:\n\n"
    ]
    if by_category:
        ranked = sorted(by_category.items(), key=lambda item: item[1], reverse=True)
        for rank, (cat, val) in enumerate(ranked[:8], 1):
            share = val / total * 100 if total else 0
            pieces.append(f"{rank}. {cat}: £{val:,.2f} ({share:.1f}% of total spend). ")
    pieces.append(" Overall, this summary helps identify key categories and spending trends over the selected period.")
    return "".join(pieces)
| # ------------------ AI Summary ------------------ | |
def render_ai_summary(agg: dict, mode: str, engine: str):
    """Render the "AI Summary" card, filled by the model or the heuristic.

    A placeholder card is shown first. For the ``swiss-ai/apertus-8b-instruct``
    engine the model is prompted with the aggregated figures and its reply is
    normalised (currency symbol, date format, spacing, Markdown removal); for
    any other engine ``heuristic_summary`` is used. The body is then prefixed
    with a bold headline of total and monthly-average spend and written into
    the placeholder as HTML.

    Args:
        agg: Aggregations dict; reads ``total_spend``, ``avg_monthly_spend``,
            ``spend_per_category`` and ``spikes``.
        mode: Summary style inserted (lower-cased) into the prompt and passed
            to ``heuristic_summary``.
        engine: Selected engine identifier.
    """
    import html  # function-scope import: only this block needs it

    st.subheader("AI Summary")
    placeholder = st.empty()
    placeholder.markdown("<div class='ai-card'>Generating summary…</div>", unsafe_allow_html=True)

    if engine == "swiss-ai/apertus-8b-instruct":
        payload = {
            'total_spend': agg.get('total_spend'),
            'avg_monthly_spend': agg.get('avg_monthly_spend'),
            'top_categories': agg.get('spend_per_category'),
            'spikes': agg.get('spikes'),
        }
        prompt = (
            f"Generate a {mode.lower()} human-readable spending summary in clear sentences. "
            f"Use only the pound symbol (£) for all amounts. Ensure proper spacing between words and numbers and correct date formatting (YYYY-MM-DD). "
            f"Avoid exposing internal column names or raw boolean values. "
            f"Data: {json.dumps(payload, default=str)}"
        )
        raw_text = _call_llama_space(prompt)
        raw_text = raw_text.replace("$", "£")
        # Rewrite bare 8-digit dates (YYYYMMDD) as YYYY-MM-DD.
        raw_text = re.sub(r"\b(\d{4})(\d{2})(\d{2})\b", r"\1-\2-\3", raw_text)
        # Drop non-ASCII noise but keep "£": it is itself non-ASCII, so a
        # blanket ASCII filter would erase every currency symbol that the
        # prompt asks for and the line above just normalised.
        raw_text = re.sub(r"[^\x00-\x7F£]+", "", raw_text)
        raw_text = sanitize_llm_text_for_users(raw_text)
        # Separate a digit from a following letter or pound sign ...
        raw_text = re.sub(r"(\d)([A-Za-z£])", r"\1 \2", raw_text)
        # ... and a letter from a following digit. "£" is deliberately not
        # matched here so that amounts stay written as "£100", not "£ 100".
        raw_text = re.sub(r"([A-Za-z])(\d)", r"\1 \2", raw_text)
        clean_body = strip_markdown(raw_text).strip()
    else:
        clean_body = heuristic_summary(agg, mode).strip()

    clean_body = sanitize_llm_text_for_users(clean_body)
    # The body is rendered with unsafe_allow_html=True, so escape it first:
    # model output must not be able to inject markup into the page.
    safe_body = html.escape(clean_body, quote=False)

    header_md = f"The total spending over the period was **£{agg.get('total_spend', 0):,.2f}**, averaging **£{agg.get('avg_monthly_spend', 0):,.2f} per month**. "
    combined = header_md + " " + safe_body
    html_safe = markdown_bold_to_html(combined)
    placeholder.markdown(
        f"""
        <div class='ai-card' style='font-size:0.9rem;color:#E0E0E0;word-wrap: break-word; white-space: pre-line; position: relative;'>
        {html_safe}
        </div>
        """,
        unsafe_allow_html=True,
    )
| # ------------------ Main ------------------ | |
def main():
    """Build the whole page: styles, header, sidebar, metrics, charts, summary.

    Flow: initialise session state, inject the CSS, render the header and
    sidebar, filter the transactions with the sidebar selections, and stop
    early with an info message when nothing matches. Otherwise render the
    KPI cards, the allowance widget, the charts, the AI summary, the table
    of large transactions, and finally the CSV download and row count.
    """
    init_session_state()

    page_css = """
    <style>
    .metric-card, .isa-widget, .ai-card, .chart-card {
    padding: 15px;
    border-radius: 12px;
    margin-bottom: 10px;
    cursor: pointer;
    box-shadow: 0 10px 30px rgba(0,174,239,0.2);
    transform: scale(1.03);
    background-color: rgba(0,174,239,0.1);
    transition: all 0.3s ease;
    }
    .metric-card:hover { box-shadow: 0 12px 40px rgba(0,174,239,0.35); transform: scale(1.05); }
    .ai-card {
    background-color: rgba(0,204,153,0.12);
    transform: scale(1.01);
    box-shadow: 0 6px 20px rgba(0,204,153,0.15);
    position: relative;
    }
    .ai-card::before {
    content: '';
    position: absolute;
    top: 0; left: 0; right: 0; bottom: 0;
    border-radius: 12px;
    box-shadow: 0 0 40px rgba(0, 204, 153, 0.3);
    z-index: -1;
    transition: all 0.3s ease;
    }
    .ai-card:hover::before { box-shadow: 0 0 50px rgba(0, 204, 153, 0.45); }
    .isa-widget {background-color: rgba(0,174,239,0.08); transform: scale(1.02); box-shadow: 0 8px 25px rgba(0,174,239,0.15);}
    .chart-card {background-color: rgba(255,255,255,0.05); transform: scale(1.01); box-shadow: 0 6px 25px rgba(0,174,239,0.15);}
    table, th, td {
    border:1px solid #555;
    padding:5px;
    text-align:left;
    background-color:#222;
    color:#E0E0E0;
    }
    th {font-weight:600;}
    </style>
    """
    st.markdown(page_css, unsafe_allow_html=True)

    render_header()
    render_assistant_banner()
    render_chat_fab()
    render_sidebar(st.session_state.data)

    # Read the filters only after the sidebar has refreshed them.
    filters = st.session_state.filters
    filtered = filter_transactions(
        st.session_state.data,
        date_range=filters["date_range"],
        categories=filters["categories"],
        merchant_query=filters["merchant_query"],
    )
    if filtered.empty:
        st.info("No data for selected filters. Adjust filters to see insights.")
        return

    agg = compute_aggregations(filtered)

    st.markdown("<div class='card'>", unsafe_allow_html=True)
    render_metrics(agg)
    st.markdown("</div>", unsafe_allow_html=True)

    with st.expander("Allowance widget"):
        allowance = st.number_input("Annual allowance (£)", min_value=0, value=20000, step=500)
        render_isa_widget(current_spend=float(agg["total_spend"]), allowance=float(allowance))

    template = "plotly_dark"
    render_charts(filtered, agg, template, show_spikes=filters["show_spikes"])
    render_ai_summary(agg, mode=filters["summary_mode"], engine=filters["engine"])

    threshold = filters["large_tx_threshold"]
    is_large = filtered["Amount"] >= threshold
    large_df = filtered[is_large].sort_values("Amount", ascending=False)
    with st.expander(f"Show large transactions (≥ £{threshold}) [{len(large_df)}]"):
        st.dataframe(large_df, use_container_width=True, hide_index=True)

    st.divider()
    download_column, size_column = st.columns([2, 1])
    with download_column:
        st.caption("Download filtered data")
        csv = filtered.to_csv(index=False).encode("utf-8")
        st.download_button("Download CSV", csv, file_name="transactions_filtered.csv", mime="text/csv")
    with size_column:
        st.caption("Dataset size")
        st.write(f"{len(filtered):,} rows")
# Build the page only when this file is run as the main module.
if __name__ == "__main__":
    main()