# Source: Hugging Face Space by Ekow24, commit e934937 ("Update app.py", verified).
# app.py
import html
import json
import os
import re
from datetime import datetime

import pandas as pd
import streamlit as st
from huggingface_hub import InferenceClient
# ------------------ Setup PublicAI Inference Client ------------------
@st.cache_resource
def get_publicai_client():
    """Build the Hugging Face ``InferenceClient`` used for PublicAI model calls.

    The client is cached with ``st.cache_resource`` so it is created once and reused.
    The API token is read from the ``publicai`` environment variable (a Space
    secret). ``PUBLICAI`` is accepted as a fallback: environment variable names
    are case-sensitive on Linux, and the error message below tells users to set
    the upper-case name.

    Returns:
        InferenceClient: a client authenticated with the token.

    Raises:
        ValueError: if neither variable is set to a non-blank value.
    """
    token = (os.getenv("publicai") or os.getenv("PUBLICAI") or "").strip()
    if not token:
        raise ValueError("Missing PublicAI token. Set 'PUBLICAI' env variable in the space secrets.")
    return InferenceClient(api_key=token)
def _call_llama_space(prompt: str, model: str = "swiss-ai/apertus-8b-instruct") -> str:
    """Send a single-turn chat prompt to the PublicAI-hosted model and return its reply.

    Args:
        prompt: the user message text.
        model: model id passed to ``chat.completions.create``. It defaults to the
            Apertus model this app was built around.

    Returns:
        The stripped reply text. If the call fails or the model returns no text,
        a string starting with ``"Error calling PublicAI model:"`` is returned
        instead. The UI displays the returned string directly, so failures are
        reported as text rather than raised.
    """
    try:
        client = get_publicai_client()
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
        )
        content = response.choices[0].message.content
    except Exception as e:  # best-effort: surface any client/network/token failure in the UI
        return f"Error calling PublicAI model: {e}"
    # content may be None (e.g. a filtered or empty completion); report that
    # explicitly instead of failing with an unrelated AttributeError message.
    text = (content or "").strip()
    if not text:
        return "Error calling PublicAI model: empty response"
    return text
# ------------------ Utils ------------------
try:
from .utils import (
generate_synthetic_transactions,
filter_transactions,
compute_aggregations,
build_time_series_chart,
build_category_bar_chart,
build_payment_method_pie_chart,
)
except Exception:
from utils import (
generate_synthetic_transactions,
filter_transactions,
compute_aggregations,
build_time_series_chart,
build_category_bar_chart,
build_payment_method_pie_chart,
)
# ------------------ Streamlit Setup ------------------
# Configure the page title, favicon and full-width ("wide") layout for the app.
st.set_page_config(page_title="AI Spending Analyser", page_icon="💳", layout="wide")
# ------------------ Helpers for Markdown/Spikes ------------------
def strip_markdown(text: str) -> str:
    """Remove inline Markdown emphasis, code and strikethrough markers from ``text``.

    These forms are replaced by their inner text:

    * ``**bold**`` and ``__bold__``
    * ``*italic*`` and ``_italic_``
    * inline ``code``
    * ``~~strike~~``

    Underscore emphasis only counts when it is not part of a word. ``*`` only
    pairs when it hugs non-space text. This keeps identifiers such as
    ``spend_per_category`` and arithmetic such as ``2 * 3 * 4`` intact.

    Non-string input is returned unchanged.
    """
    if not isinstance(text, str):
        return text
    text = re.sub(r"\*\*(.*?)\*\*", r"\1", text)
    # Underscore bold/italic must not sit inside a word (CommonMark-like intraword rule).
    text = re.sub(r"(?<!\w)__(.*?)__(?!\w)", r"\1", text)
    # Single-asterisk emphasis needs non-space text right inside both markers.
    text = re.sub(r"(?<!\*)\*(?![\s*])(.+?)(?<![\s*])\*(?!\*)", r"\1", text)
    text = re.sub(r"(?<!\w)_(?!\s)(.+?)(?<!\s)_(?!\w)", r"\1", text)
    text = re.sub(r"`(.*?)`", r"\1", text)
    text = re.sub(r"~~(.*?)~~", r"\1", text)
    return text
def markdown_bold_to_html(text: str) -> str:
    """Return ``text`` with every ``**...**`` span rewritten as an HTML ``<strong>`` element.

    Matching is non-greedy, so each pair of ``**`` markers forms its own element.
    """
    bold_span = re.compile(r"\*\*(.*?)\*\*")
    return bold_span.sub(lambda match: "<strong>" + match.group(1) + "</strong>", text)
def sanitize_llm_text_for_users(text: str) -> str:
    """Replace internal/raw tokens in generated text with user-friendly words.

    Three whole-word substitutions are applied in order:

    * ``IsSpike`` (any case) becomes "spike indicator".
    * ``True`` (exact case) becomes "yes".
    * ``False`` (exact case) becomes "no".

    Non-string input is returned unchanged.
    """
    if isinstance(text, str):
        for pattern, replacement, flags in (
            (r"\bIsSpike\b", "spike indicator", re.IGNORECASE),
            (r"\bTrue\b", "yes", 0),
            (r"\bFalse\b", "no", 0),
        ):
            text = re.sub(pattern, replacement, text, flags=flags)
    return text
# ------------------ Session State ------------------
def init_session_state():
    """Seed ``st.session_state`` on the first run.

    Adds 900 synthetic transactions (seed 42) and default filters: the full date
    range of that data, no categories, and an empty merchant query.
    Existing values are never overwritten.
    """
    state = st.session_state
    if "data" not in state:
        state.data = generate_synthetic_transactions(n_rows=900, seed=42)
    if "filters" not in state:
        dates = state.data["Date"]
        state.filters = {
            "date_range": (dates.min(), dates.max()),
            "categories": [],
            "merchant_query": "",
        }
# ------------------ Header ------------------
def render_header():
    """Render the page banner: a large cyan caret beside the app title."""
    banner_html = """
<div style='display: flex; align-items: baseline; gap: 15px; margin-bottom: 20px;'>
<div style='font-size: 80px; color: #00AEEF; font-weight: bold; line-height: 1;'>^</div>
<div style='font-size: 36px; color: #697089; font-weight: 500; line-height: 1;'>AI Spending Analyser</div>
</div>
"""
    st.markdown(banner_html, unsafe_allow_html=True)
def render_assistant_banner():
    """Placeholder for an assistant banner; currently renders nothing."""
    return None
def render_chat_fab():
    """Placeholder for a floating chat button; currently renders nothing."""
    return None
# ------------------ Sidebar ------------------
def render_sidebar(df: pd.DataFrame):
    """Draw the sidebar controls and store the chosen settings in ``st.session_state.filters``.

    The date pickers are bounded by ``df["Date"]``. If "From" is after "To", an
    error is shown and the full range is used instead. The stored date range
    spans from the start of the first day to the end of the last day. Pressing
    "Regenerate" replaces ``st.session_state.data`` with 900 fresh unseeded rows.
    """
    sb = st.sidebar
    sb.header("Filters")
    first_day = df["Date"].min().date()
    last_day = df["Date"].max().date()
    left, right = sb.columns(2)
    with left:
        from_date = st.date_input("From", value=first_day, min_value=first_day, max_value=last_day)
    with right:
        to_date = st.date_input("To", value=last_day, min_value=first_day, max_value=last_day)
    if from_date > to_date:
        sb.error("From date cannot be after To date")
        from_date, to_date = first_day, last_day

    categories = sb.multiselect("Category", options=sorted(df["Category"].unique()), default=[])
    merchant_query = sb.text_input("Merchant search", value="", placeholder="Type a merchant name…")

    sb.divider()
    sb.header("AI")
    summary_mode = sb.radio("Summary", options=["Concise", "Detailed"], index=0, horizontal=True)
    engine = sb.selectbox("Engine", options=["swiss-ai/apertus-8b-instruct"], index=0)

    sb.divider()
    sb.header("Anomalies & Highlights")
    show_spikes = sb.toggle("Show spike markers", value=True)
    large_tx_threshold = sb.slider("Large transaction threshold (£)", 50, 1000, 250, step=25)
    if sb.button("Regenerate"):
        st.session_state.data = generate_synthetic_transactions(n_rows=900)

    range_start = datetime.combine(from_date, datetime.min.time())
    range_end = datetime.combine(to_date, datetime.max.time())
    st.session_state.filters = {
        "date_range": (range_start, range_end),
        "categories": categories,
        "merchant_query": merchant_query.strip(),
        "summary_mode": summary_mode,
        "engine": engine,
        "show_spikes": show_spikes,
        "large_tx_threshold": large_tx_threshold,
    }
# ------------------ Metrics ------------------
def render_metrics(agg: dict):
    """Show the headline KPIs as four side-by-side cards.

    The cards show total spend, average monthly spend, and the amounts of the
    largest and smallest transactions. Each value is formatted as whole pounds
    with smaller pence.
    """
    columns = st.columns(4)
    cards = [
        ("Total Value", agg["total_spend"]),
        ("Avg Monthly", agg["avg_monthly_spend"]),
        ("Max Transaction", agg["max_transaction"]["Amount"]),
        ("Min Transaction", agg["min_transaction"]["Amount"]),
    ]
    for column, (label, value) in zip(columns, cards):
        amount = f"{value:,.2f}"
        pounds, pence = amount.split(".")
        column.markdown(f"""
<div class='metric-card' title='{label}: £{amount}'>
<div class='metric-label'>{label}</div>
<div class='kpi-value'><span style='font-size:0.9em;'>£</span><span style='font-size:2.3em;font-weight:bold;'>{pounds}</span><span style='font-size:1em;'>.{pence}</span></div>
</div>
""", unsafe_allow_html=True)
# ------------------ ISA Widget ------------------
def render_isa_widget(current_spend: float, allowance: float):
    """Show how much of an annual allowance ``current_spend`` has used.

    ``used`` is capped at the allowance and ``remaining`` never goes below zero.
    The progress-bar width is clamped to 0-100 %, so a negative spend still
    draws a valid bar.

    Args:
        current_spend: amount counted against the allowance.
        allowance: annual allowance in pounds. Values <= 0 show an empty bar.
    """
    used = min(current_spend, allowance)
    remaining = max(allowance - used, 0)
    percent = 0 if allowance <= 0 else int((used / allowance) * 100)
    percent = max(0, min(percent, 100))  # negative spend would otherwise yield an invalid CSS width
    # NOTE(review): each st.markdown call renders in its own element, so this
    # opening/closing <div> pair probably does not wrap the widgets in between — confirm.
    st.markdown("<div class='isa-widget'>", unsafe_allow_html=True)
    st.subheader("ISA allowance")
    # The page stylesheet has no `.progress` rule, so size and colour the track
    # and fill inline; without this the fill div has zero height and is invisible.
    st.markdown(
        "<div class='progress' style='height:10px;border-radius:6px;"
        "background-color:rgba(255,255,255,0.15);overflow:hidden;'>"
        f"<div style='width:{percent}%;height:100%;background-color:#00AEEF;'></div></div>",
        unsafe_allow_html=True,
    )
    col1, col2 = st.columns(2)
    with col1:
        used_whole, used_pence = f"{used:,.2f}".split(".")
        st.markdown(f"<div><span style='font-weight:600;'>USED</span><br/><span style='font-size:2em;font-weight:bold;'>£{used_whole}</span><span style='font-size:1em;'>.{used_pence}</span></div>", unsafe_allow_html=True)
    with col2:
        rem_whole, rem_pence = f"{remaining:,.2f}".split(".")
        st.markdown(f"<div><span style='font-weight:600;color:rgba(255,255,255,0.8)'>REMAINING</span><br/><span style='font-size:2em;font-weight:bold;'>£{rem_whole}</span><span style='font-size:1em;'>.{rem_pence}</span></div>", unsafe_allow_html=True)
    st.markdown("</div>", unsafe_allow_html=True)
# ------------------ Charts ------------------
def render_charts(filtered_df: pd.DataFrame, agg: dict, template: str, show_spikes: bool):
    """Draw the trend, category and payment-method charts in three tabs.

    Each Plotly figure is bracketed by an opening and closing ``chart-card``
    div. Spike markers are passed to the trend chart only when ``show_spikes``
    is true.
    """
    trend_tab, category_tab, payment_tab = st.tabs(["Trend", "By Category", "Payment Methods"])

    def show(figure):
        # Same three calls for every tab: card opener, the chart, card closer.
        st.markdown("<div class='chart-card'>", unsafe_allow_html=True)
        st.plotly_chart(figure, use_container_width=True, responsive=True)
        st.markdown("</div>", unsafe_allow_html=True)

    with trend_tab:
        overlay = agg["spikes"] if show_spikes else None
        show(build_time_series_chart(filtered_df, template=template, spike_overlay=overlay))
    with category_tab:
        st.caption("Tip: Select categories in the sidebar to compare their total spend.")
        show(build_category_bar_chart(agg["spend_per_category"], template=template))
    with payment_tab:
        show(build_payment_method_pie_chart(agg["spend_per_payment"], template=template))
# ------------------ Heuristic AI Summary ------------------
def heuristic_summary(agg: dict, mode: str, date_range=None) -> str:
    """Build a rule-based, plain-text spending summary (no model call).

    Args:
        agg: aggregation dict. The keys read are:

            * ``total_spend`` and ``avg_monthly_spend``.
            * ``spend_per_category``: a mapping or pandas Series of
              category -> amount.
            * ``spikes``: a DataFrame expected to have ``Date`` and ``Amount``
              columns; it may be missing, ``None`` or empty.
        mode: summary style requested in the sidebar. NOTE(review): currently
            unused, so "Concise" and "Detailed" produce the same text.
        date_range: optional ``(start, end)`` pair used for the day count. When
            omitted, it is read from ``st.session_state.filters``.

    Returns:
        The summary text: a period overview, one sentence about the largest
        spike, and up to eight categories ranked by spend with their share of
        the total.
    """
    total = agg.get("total_spend", 0)
    avg_month = agg.get("avg_monthly_spend", 0)
    spend_per_category = agg.get("spend_per_category")
    df_spikes = agg.get("spikes")

    if date_range is None:
        date_range = st.session_state.filters.get("date_range", (None, None))
    start, end = date_range[0], date_range[1]
    # +1 so a single-day range (00:00 .. 23:59) counts as one day.
    num_days = (end - start).days + 1 if start is not None and end is not None else "unknown"

    if df_spikes is None or len(df_spikes) == 0:
        spike_text = "No unusual spending spikes were detected during this period."
    else:
        try:
            top_spike = df_spikes.loc[df_spikes["Amount"].idxmax()]
            spike_date = pd.to_datetime(top_spike["Date"]).strftime("%Y-%m-%d")
            spike_amt = top_spike["Amount"]
            spike_text = f"A noticeable spending spike occurred on {spike_date}, reaching £{spike_amt:,.2f}, which exceeded the 28-day rolling average."
        except Exception:
            # Best-effort: an unexpected spike layout degrades to a generic sentence.
            spike_text = "Some days showed higher-than-normal spending compared to the 28-day rolling average."

    text = f"Over a period of {num_days} days, the total spend was £{total:,.2f} with an average monthly spend of £{avg_month:,.2f}. {spike_text} Here’s a breakdown of the spending patterns:\n\n"
    # Explicit length check: truth-testing a pandas Series raises ValueError.
    if spend_per_category is not None and len(spend_per_category) > 0:
        ranked = sorted(spend_per_category.items(), key=lambda item: item[1], reverse=True)
        entries = []
        for rank, (cat, val) in enumerate(ranked[:8], 1):
            pct = val / total * 100 if total else 0
            entries.append(f"{rank}. {cat}: £{val:,.2f} ({pct:.1f}% of total spend).")
        text += " ".join(entries)
    text += " Overall, this summary helps identify key categories and spending trends over the selected period."
    return text
# ------------------ AI Summary ------------------
# Typographic characters models commonly emit, mapped to ASCII look-alikes.
# Without this, the non-ASCII filter in _clean_llm_summary would delete them
# and glue neighbouring words together (e.g. "spend—mostly" -> "spendmostly").
_LLM_PUNCTUATION_MAP = str.maketrans({
    "\u2018": "'",
    "\u2019": "'",
    "\u201c": '"',
    "\u201d": '"',
    "\u2013": "-",
    "\u2014": " - ",
    "\u2026": "...",
    "\u00a0": " ",
})


def _clean_llm_summary(raw_text: str) -> str:
    """Normalise raw model output into plain, user-facing text (currency, dates, spacing, markdown)."""
    text = raw_text.replace("$", "£")
    # Compact YYYYMMDD -> YYYY-MM-DD, restricted to plausible dates so other
    # 8-digit numbers are left alone.
    text = re.sub(r"\b((?:19|20)\d{2})(0[1-9]|1[0-2])(0[1-9]|[12]\d|3[01])\b", r"\1-\2-\3", text)
    text = text.translate(_LLM_PUNCTUATION_MAP)
    # Drop the remaining non-ASCII (emoji, stray symbols), but keep the pound
    # sign: the prompt asks for it and "$" was just converted to it.
    text = re.sub(r"[^\x00-\x7F£]+", "", text)
    text = sanitize_llm_text_for_users(text)
    # Re-insert missing spaces between numbers and words ("spent£50on" -> "spent £50 on").
    # Ordinals ("1st", "3rd") and amounts ("£50") are not split.
    text = re.sub(r"(\d)(?!(?:st|nd|rd|th)\b)([A-Za-z£])", r"\1 \2", text)
    text = re.sub(r"([A-Za-z])([\d£])", r"\1 \2", text)
    return strip_markdown(text).strip()


def render_ai_summary(agg: dict, mode: str, engine: str):
    """Render the "AI Summary" card for the current aggregations.

    A "Generating summary…" placeholder is shown first and then replaced by the
    final card.

    Args:
        agg: aggregation dict. Reads ``total_spend``, ``avg_monthly_spend``,
            ``spend_per_category`` and ``spikes``.
        mode: summary style label (e.g. "Concise"/"Detailed"). It is
            lower-cased into the prompt and passed to ``heuristic_summary``.
        engine: ``"swiss-ai/apertus-8b-instruct"`` queries the PublicAI model.
            Any other value uses the rule-based ``heuristic_summary``.
    """
    st.subheader("AI Summary")
    placeholder = st.empty()
    placeholder.markdown("<div class='ai-card'>Generating summary…</div>", unsafe_allow_html=True)
    if engine == "swiss-ai/apertus-8b-instruct":
        payload = {
            "total_spend": agg.get("total_spend"),
            "avg_monthly_spend": agg.get("avg_monthly_spend"),
            "top_categories": agg.get("spend_per_category"),
            "spikes": agg.get("spikes"),
        }
        prompt = (
            f"Generate a {mode.lower()} human-readable spending summary in clear sentences. "
            f"Use only the pound symbol (£) for all amounts. Ensure proper spacing between words and numbers and correct date formatting (YYYY-MM-DD). "
            f"Avoid exposing internal column names or raw boolean values. "
            f"Data: {json.dumps(payload, default=str)}"
        )
        clean_body = _clean_llm_summary(_call_llama_space(prompt))
    else:
        clean_body = heuristic_summary(agg, mode).strip()
    clean_body = sanitize_llm_text_for_users(clean_body)
    header_md = f"The total spending over the period was **£{agg.get('total_spend', 0):,.2f}**, averaging **£{agg.get('avg_monthly_spend', 0):,.2f} per month**. "
    # The body is generated text rendered with unsafe_allow_html=True, so escape
    # it; otherwise any "<" or "&" it contains would be interpreted as markup.
    combined = header_md + " " + html.escape(clean_body, quote=False)
    html_safe = markdown_bold_to_html(combined)
    # A blank line ends a raw-HTML block in Markdown and would spill the text
    # outside the styled card, so carry line breaks as <br/> instead.
    html_safe = html_safe.replace("\n", "<br/>")
    placeholder.markdown(
        f"""
<div class='ai-card' style='font-size:0.9rem;color:#E0E0E0;word-wrap: break-word; white-space: pre-line; position: relative;'>
{html_safe}
</div>
""",
        unsafe_allow_html=True,
    )
# ------------------ Main ------------------
# Page-wide stylesheet injected once per run by main().
_DASHBOARD_CSS = """
<style>
.metric-card, .isa-widget, .ai-card, .chart-card {
padding: 15px;
border-radius: 12px;
margin-bottom: 10px;
cursor: pointer;
box-shadow: 0 10px 30px rgba(0,174,239,0.2);
transform: scale(1.03);
background-color: rgba(0,174,239,0.1);
transition: all 0.3s ease;
}
.metric-card:hover { box-shadow: 0 12px 40px rgba(0,174,239,0.35); transform: scale(1.05); }
.ai-card {
background-color: rgba(0,204,153,0.12);
transform: scale(1.01);
box-shadow: 0 6px 20px rgba(0,204,153,0.15);
position: relative;
}
.ai-card::before {
content: '';
position: absolute;
top: 0; left: 0; right: 0; bottom: 0;
border-radius: 12px;
box-shadow: 0 0 40px rgba(0, 204, 153, 0.3);
z-index: -1;
transition: all 0.3s ease;
}
.ai-card:hover::before { box-shadow: 0 0 50px rgba(0, 204, 153, 0.45); }
.isa-widget {background-color: rgba(0,174,239,0.08); transform: scale(1.02); box-shadow: 0 8px 25px rgba(0,174,239,0.15);}
.chart-card {background-color: rgba(255,255,255,0.05); transform: scale(1.01); box-shadow: 0 6px 25px rgba(0,174,239,0.15);}
table, th, td {
border:1px solid #555;
padding:5px;
text-align:left;
background-color:#222;
color:#E0E0E0;
}
th {font-weight:600;}
</style>
"""


def _render_large_transactions(filtered: pd.DataFrame, threshold) -> None:
    """List rows with ``Amount >= threshold`` (largest first) in a collapsible section."""
    big = filtered.loc[filtered["Amount"] >= threshold].sort_values("Amount", ascending=False)
    with st.expander(f"Show large transactions (≥ £{threshold}) [{len(big)}]"):
        st.dataframe(big, use_container_width=True, hide_index=True)


def _render_download_footer(filtered: pd.DataFrame) -> None:
    """Render the CSV download button and the row count under a divider."""
    st.divider()
    left, right = st.columns([2, 1])
    with left:
        st.caption("Download filtered data")
        payload = filtered.to_csv(index=False).encode("utf-8")
        st.download_button("Download CSV", payload, file_name="transactions_filtered.csv", mime="text/csv")
    with right:
        st.caption("Dataset size")
        st.write(f"{len(filtered):,} rows")


def main():
    """Render the whole dashboard for one Streamlit run.

    The sections appear in this order:

    1. Styles and header.
    2. Sidebar filters.
    3. KPI cards.
    4. Allowance widget.
    5. Charts.
    6. AI summary.
    7. Large-transaction table.
    8. CSV download.

    If no rows match the filters, only an info message is shown after the sidebar.
    """
    init_session_state()
    st.markdown(_DASHBOARD_CSS, unsafe_allow_html=True)
    render_header()
    render_assistant_banner()
    render_chat_fab()

    state = st.session_state
    render_sidebar(state.data)
    filters = state.filters
    # Read state.data again: the sidebar's "Regenerate" button may have replaced it.
    filtered = filter_transactions(
        state.data,
        date_range=filters["date_range"],
        categories=filters["categories"],
        merchant_query=filters["merchant_query"],
    )
    if filtered.empty:
        st.info("No data for selected filters. Adjust filters to see insights.")
        return

    agg = compute_aggregations(filtered)
    st.markdown("<div class='card'>", unsafe_allow_html=True)
    render_metrics(agg)
    st.markdown("</div>", unsafe_allow_html=True)
    with st.expander("Allowance widget"):
        allowance = st.number_input("Annual allowance (£)", min_value=0, value=20000, step=500)
        render_isa_widget(current_spend=float(agg["total_spend"]), allowance=float(allowance))
    render_charts(filtered, agg, "plotly_dark", show_spikes=filters["show_spikes"])
    render_ai_summary(agg, mode=filters["summary_mode"], engine=filters["engine"])
    _render_large_transactions(filtered, filters["large_tx_threshold"])
    _render_download_footer(filtered)
# Entry point: build the dashboard when this file is executed as the main
# script (e.g. `streamlit run app.py`).
if __name__ == "__main__":
    main()