rmd826's picture
Upload 9 files
a54f28a verified
# -*- coding: utf-8 -*-
"""
Tax Torpedo Analyzer
====================
Streamlit app: 3 inputs -> torpedo plot + summary -> 3 quick-action buttons.
Uses the analyst's "reference taxable income" x-axis convention:
x_plot = OI - Std. Ded. + 0.85 * SSB
No chatbot UI — the LLM only responds to the three quick-action buttons,
and its response replaces the analysis text under the plot.
"""
from __future__ import annotations
import hmac
import os
import re
import streamlit as st
# ---------------------------------------------------------------------------
# API key (set before any imports that might use it)
# ---------------------------------------------------------------------------
# Make sure GEMINI_API_KEY is defined (empty string when absent) before the
# project modules imported below are loaded.
os.environ.setdefault("GEMINI_API_KEY", "")
from user_session import UserProfile
from chat_orchestrator import ChatOrchestrator
from gemini_tools import calculate_tax_situation, generate_torpedo_plot
# ---------------------------------------------------------------------------
# Password gate
# ---------------------------------------------------------------------------
def check_password() -> bool:
    """Gate the app behind the ``APP_PASSWORD`` environment variable.

    When ``APP_PASSWORD`` is unset or empty, access is granted immediately
    (local development). Otherwise a login form is rendered, and ``True`` is
    returned only after the user has submitted the matching password. Success
    is remembered in ``st.session_state["password_correct"]`` so later reruns
    skip the form.

    Returns:
        True if the app may be shown, False while the login form is displayed.
    """
    expected = os.environ.get("APP_PASSWORD", "")
    if not expected:
        return True  # No password configured — allow access (local dev)
    if st.session_state.get("password_correct", False):
        return True

    st.markdown(
        '<h1 style="text-align:center; color:#1a237e; margin-top:80px;">'
        "Tax Torpedo Analyzer</h1>"
        '<p style="text-align:center; color:#666; font-size:16px;">'
        "Please enter the password to continue.</p>",
        unsafe_allow_html=True,
    )
    password = st.text_input("Password", type="password", key="password_input")
    if st.button("Log in", type="primary"):
        # hmac.compare_digest only accepts ASCII-only ``str`` arguments and
        # raises TypeError otherwise. Compare UTF-8 bytes so any password works.
        supplied = (password or "").encode("utf-8")
        if hmac.compare_digest(supplied, expected.encode("utf-8")):
            st.session_state["password_correct"] = True
            st.rerun()
        else:
            st.error("Incorrect password. Please try again.")
    return False
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Sidebar filing-status label -> code passed to the tax tools.
STATUS_MAP = {
    "Single": "SGL",
    "Married Filing Jointly": "MFJ",
    "Head of Household": "HOH",
    "Married Filing Separately": "MFS",
}
# Selectbox options, in the same order as STATUS_MAP.
STATUS_CHOICES = [*STATUS_MAP]

# Large pill badge markup for each tax zone; {bg}/{fg} are the background and
# text colours, {name} is the zone label.
_BADGE_PILL_TEMPLATE = (
    '<span style="display:inline-block;padding:6px 16px;border-radius:20px;'
    'font-weight:700;font-size:15px;margin-right:8px;'
    'background:{bg};color:{fg};">{name}</span>'
)
ZONE_BADGE = {
    name: _BADGE_PILL_TEMPLATE.format(bg=bg, fg=fg, name=name)
    for name, bg, fg in (
        ("No-Tax Zone", "#c8e6c9", "#2e7d32"),
        ("High-Tax Zone", "#ffcdd2", "#c62828"),
        ("Same-Old Zone", "#e1bee7", "#4a148c"),
    )
}
# ---------------------------------------------------------------------------
# Custom CSS
# ---------------------------------------------------------------------------
# Stylesheet injected on each run by main() via
# st.markdown(CUSTOM_CSS, unsafe_allow_html=True). Classes defined here are
# used by HTML snippets elsewhere in this file:
#   .key-numbers-band — the summary band built by _format_key_numbers()
#   .disclaimer       — the footer disclaimer rendered at the end of main()
#   .analysis-text    — the wrapper <div> around the LLM / fallback summary
CUSTOM_CSS = """
<style>
.key-numbers-band {
background: #f0f4ff;
border: 1px solid #c5cae9;
border-radius: 12px;
padding: 12px 20px;
margin: 16px 0;
font-size: 16px;
line-height: 1.6;
}
.key-numbers-band b {
color: #1a237e;
}
.disclaimer {
text-align: center;
font-size: 13px;
color: #999;
padding: 20px;
margin-top: 20px;
border-top: 1px solid #eee;
}
.analysis-text {
font-size: 16px;
line-height: 1.7;
color: #333;
padding: 8px 0 16px 0;
}
</style>
"""
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _escape_dollars(text: str) -> str:
"""Escape bare $ signs so Streamlit doesn't render them as LaTeX.
Replaces $ that are NOT already escaped (\\$) and NOT part of
a LaTeX block ($...$) with \\$ so they display as literal dollars.
"""
# Replace $<number> patterns (like $40,000 or $1,200.50) with escaped versions
text = re.sub(r'(?<!\\)\$(\d)', r'\\$\1', text)
return text
# Zone badge colors matching the plot shading
_ZONE_COLORS = {
"No-Tax Zone": {"bg": "#c8e6c9", "fg": "#2e7d32"},
"High-Tax Zone": {"bg": "#ffcdd2", "fg": "#c62828"},
"Same-Old Zone": {"bg": "#e1bee7", "fg": "#4a148c"},
}
def _zone_badge_html(zone_name: str, font_size: int = 13) -> str:
"""Return an inline HTML badge for a tax zone name, color-coded to match the plot."""
colors = _ZONE_COLORS.get(zone_name, {"bg": "#e0e0e0", "fg": "#333"})
return (
f'<span style="display:inline-block;padding:3px 10px;border-radius:12px;'
f'font-weight:700;font-size:{font_size}px;'
f'background:{colors["bg"]};color:{colors["fg"]};">'
f'{zone_name}</span>'
)
def _format_key_numbers(tax_result: dict) -> str:
    """Render the "key numbers" band as an HTML ``<div class="key-numbers-band">``.

    Line one shows the zone badge, tax owed, marginal and effective rates,
    taxable SSB (with its percentage) and take-home pay. Line two shows the
    zero and confluence points on the reference-taxable-income axis, or
    "N/A" when either is missing.

    Args:
        tax_result: Result dict from ``calculate_tax_situation``; missing
            numeric keys are shown as 0.

    Returns:
        HTML for ``st.markdown(..., unsafe_allow_html=True)``.
    """
    zone = tax_result.get("tax_zone", "Unknown")
    if zone in ZONE_BADGE:
        badge = ZONE_BADGE[zone]
    else:
        # Grey pill for zones without a predefined badge.
        badge = (
            '<span style="display:inline-block;padding:6px 16px;border-radius:20px;'
            'font-weight:700;font-size:15px;margin-right:8px;'
            f'background:#e0e0e0;color:#333;">{zone}</span>'
        )

    def money(key):
        return f"${tax_result.get(key, 0):,.0f}"

    def point(key):
        value = tax_result.get(key)
        return "N/A" if value is None else f"${value:,.0f}"

    separator = " &nbsp;|&nbsp; "
    first_line = separator.join([
        f"<b>Tax Owed:</b> {money('tax_owed')}",
        f"<b>Marginal Rate:</b> {tax_result.get('marginal_rate', 0):.1f}%",
        f"<b>Effective Rate:</b> {tax_result.get('effective_rate', 0):.1f}%",
        f"<b>Taxable SSB:</b> {money('taxable_ssb')} "
        f"({tax_result.get('pct_ssb_taxable', 0):.0f}%)",
        f"<b>Take-Home:</b> {money('take_home')}",
    ])
    second_line = separator.join([
        f"<b>Zero Point (ref. TI):</b> {point('zero_point')}",
        f"<b>Confluence Point (ref. TI):</b> {point('confluence_point')}",
    ])
    return (
        '<div class="key-numbers-band">'
        f"{badge} &nbsp; {first_line}<br>{second_line}"
        "</div>"
    )
def _format_equations_markdown(tax_result: dict) -> str:
"""Generate a markdown section showing all key numbers with equations.
Uses \\$ to escape dollar signs so Streamlit doesn't render them as LaTeX.
"""
fs_name = tax_result.get("filing_status_name", "Unknown")
fs_code = tax_result.get("filing_status", "")
ssb = tax_result.get("ssb_annual", 0)
oi = tax_result.get("other_income", 0)
pi = tax_result.get("provisional_income", 0)
taxable_ssb = tax_result.get("taxable_ssb", 0)
pct_ssb = tax_result.get("pct_ssb_taxable", 0)
agi = tax_result.get("agi", 0)
std_ded = tax_result.get("standard_deduction", 0)
taxable_inc = tax_result.get("taxable_income", 0)
ref_ti = tax_result.get("ref_taxable_income", 0)
tax = tax_result.get("tax_owed", 0)
eff_rate = tax_result.get("effective_rate", 0)
marg_rate = tax_result.get("marginal_rate", 0)
zone = tax_result.get("tax_zone", "Unknown")
zp = tax_result.get("zero_point")
cp = tax_result.get("confluence_point")
gross = tax_result.get("gross_income", 0)
take_home = tax_result.get("take_home", 0)
zp_str = f"\\${zp:,.0f}" if zp is not None else "N/A"
cp_str = f"\\${cp:,.0f}" if cp is not None else "N/A"
eff_formula = (
f"\\${tax:,.0f} / \\${oi:,.0f} x 100" if oi > 0 else "no other income"
)
lines = [
f"**Tax Calculation Breakdown ({fs_name}, 2016 Rates)**",
"",
"| Item | Formula | Value |",
"|------|---------|-------|",
f"| **Filing Status** | | {fs_name} ({fs_code}) |",
f"| **Social Security Benefit** | | \\${ssb:,.0f} |",
f"| **Other Income** | | \\${oi:,.0f} |",
f"| **Provisional Income** | \\${oi:,.0f} + 0.5 x \\${ssb:,.0f} | **\\${pi:,.0f}** |",
f"| **Taxable SSB** | IRS 3-tier rules on PI | \\${taxable_ssb:,.0f} ({pct_ssb:.0f}% of SSB) |",
f"| **AGI** | \\${oi:,.0f} + \\${taxable_ssb:,.0f} | \\${agi:,.0f} |",
f"| **Standard Deduction** | (incl. personal exemptions) | \\${std_ded:,.0f} |",
f"| **Taxable Income** | max(0, \\${agi:,.0f} - \\${std_ded:,.0f}) | \\${taxable_inc:,.0f} |",
f"| **Ref. Taxable Income** | OI - Std.Ded. + 0.85xSSB | \\${ref_ti:,.0f} |",
f"| **Tax Owed** | bracket calc on \\${taxable_inc:,.0f} | **\\${tax:,.0f}** |",
f"| **Effective Rate** | {eff_formula} | {eff_rate:.1f}% |",
f"| **Marginal Rate** | (tax at \\${oi:,.0f}+\\$100 - tax at \\${oi:,.0f}) / \\$100 | **{marg_rate:.1f}%** |",
f"| **Tax Zone** | | {zone} |",
f"| **Zero Point** | ref. TI where tax first > \\$0 | {zp_str} |",
f"| **Confluence Point** | ref. TI where 85% SSB taxable | {cp_str} |",
f"| **Gross Income** | \\${oi:,.0f} + \\${ssb:,.0f} | \\${gross:,.0f} |",
f"| **Take-Home** | \\${gross:,.0f} - \\${tax:,.0f} | **\\${take_home:,.0f}** |",
]
return "\n".join(lines)
def _fallback_summary(tax_result: dict, error_msg: str = "") -> str:
"""Generate a plain-text summary when Gemini is unavailable."""
zone = tax_result.get("tax_zone", "Unknown")
text = f"**You are in the {zone}.**\n\n"
if zone == "No-Tax Zone":
zp = tax_result.get("zero_point")
text += (
f"You currently owe **\\$0** in federal income tax. "
f"On the chart, you have room up to a reference taxable income of "
f"\\${zp:,.0f} before any tax kicks in."
if zp else
"You currently owe **\\$0** in federal income tax."
)
elif zone == "High-Tax Zone":
text += (
f"You owe **\\${tax_result.get('tax_owed', 0):,.0f}** in federal tax. "
f"Your marginal rate is elevated at "
f"**{tax_result.get('marginal_rate', 0):.1f}%** "
f"due to the tax torpedo effect on your Social Security."
)
else:
text += (
f"You owe **\\${tax_result.get('tax_owed', 0):,.0f}** in federal tax. "
f"You are past the torpedo zone, so normal bracket rates apply."
)
if error_msg:
text += f"\n\n*Note: AI assistant unavailable ({error_msg})*"
return text
def _build_initial_llm_message(filing_status_display: str, ssb: float, other_income: float, tax_result: dict) -> str:
"""Build the initial message to send to Gemini for a summary."""
zp = tax_result.get("zero_point")
cp = tax_result.get("confluence_point")
zp_str = f"${zp:,.0f}" if zp is not None else "N/A"
cp_str = f"${cp:,.0f}" if cp is not None else "N/A"
return (
f"The user just submitted their tax information.\n"
f"Filing Status: {filing_status_display}\n"
f"Social Security Benefit: ${ssb:,.0f}/year\n"
f"Other Taxable Income: ${other_income:,.0f}/year\n\n"
f"Computed results:\n"
f"Tax Zone: {tax_result.get('tax_zone')}\n"
f"Tax Owed: ${tax_result.get('tax_owed', 0):,.0f}\n"
f"Marginal Rate: {tax_result.get('marginal_rate', 0):.1f}%\n"
f"Effective Rate: {tax_result.get('effective_rate', 0):.1f}%\n"
f"Taxable SSB: ${tax_result.get('taxable_ssb', 0):,.0f} "
f"({tax_result.get('pct_ssb_taxable', 0):.0f}%)\n"
f"Zero Point (ref. taxable income): {zp_str}\n"
f"Confluence Point (ref. taxable income): {cp_str}\n"
f"Take-Home: ${tax_result.get('take_home', 0):,.0f}\n\n"
f"CHART EXPLANATION (so you can describe it accurately to the user):\n"
f"The torpedo chart x-axis is a 'reference taxable income' used by tax analysts:\n"
f" reference income = Other Income - Standard Deduction + 0.85 * SSB\n"
f"This axis starts at $0 when Other Income = $0, treating 85% SSB inclusion as the baseline.\n\n"
f"The two curves plotted against this reference income axis are:\n"
f" BLACK DASHED LINE (baseline): taxes computed as if 85% of SSB were always fully taxable "
f"— i.e., normal bracket taxes on the reference income. This is the 'no-torpedo' benchmark.\n"
f" RED LINE (actual): the real tax owed under IRS SSB inclusion rules. In the No-Tax Zone "
f"and early Torpedo Zone, SSB inclusion is 0%-50%, so actual taxable income is LESS than "
f"reference income — the red line starts BELOW the black dashed line.\n\n"
f"The key visual pattern:\n"
f" - In the Torpedo Zone, each $1 of additional OI triggers an extra $0.50-$0.85 of SSB "
f"to become taxable, so the effective marginal rate is 1.5x-1.85x the normal bracket rate. "
f"The red line climbs steeply upward toward the black line.\n"
f" - At the Confluence Point ({cp_str} on this axis), 85% of SSB has become fully "
f"taxable. Actual taxable income now equals the reference income, so the two curves meet.\n"
f" - After the Confluence Point, the lines overlap — normal bracket rates apply again.\n\n"
f"Please provide a clear, friendly, plain-English summary of their "
f"tax situation in 3-5 sentences. Explain what zone they are in, "
f"what the key numbers mean, and give 1-2 actionable insights. "
f"Reference the torpedo chart that is shown above. Keep it concise.\n\n"
f"IMPORTANT: Do NOT call any tools for this initial summary. "
f"All the numbers have already been computed above. Just write "
f"a plain-text summary using these numbers."
)
def _build_improve_message(tax_result: dict) -> str:
"""Build the LLM message for the 'Improve my tax situation' button.
For the High-Tax Zone, asks for a SINGLE combined scenario plot showing
both the zero-point move and the confluence-point move together.
"""
zone = tax_result.get("tax_zone", "")
oi = tax_result.get("other_income", 0)
fs = tax_result.get("filing_status", "MFJ")
ssb = tax_result.get("ssb_annual", 0)
zp_oi = tax_result.get("zero_point_oi")
cp_oi = tax_result.get("confluence_point_oi")
if zone == "No-Tax Zone":
return (
"I'm currently in the No-Tax Zone and pay $0 in federal tax. "
"How much more income can I take before I start paying tax? "
"Show me on the torpedo curve using the scenario comparison tool."
)
elif zone == "High-Tax Zone":
# Build a request for a SINGLE combined plot with both scenarios
new_incomes = []
labels = []
if zp_oi is not None:
new_incomes.append(f"{zp_oi:.0f}")
labels.append(f'"Move to Zero Point (OI=${zp_oi:,.0f})"')
if cp_oi is not None:
new_incomes.append(f"{cp_oi:.0f}")
labels.append(f'"Move Past Torpedo (OI=${cp_oi:,.0f})"')
incomes_str = ", ".join(new_incomes)
labels_str = ", ".join(labels)
return (
f"I'm in the High-Tax Zone with other income of ${oi:,.0f}. "
f"My filing status is {fs} and my SSB is ${ssb:,.0f}.\n\n"
f"Please call generate_scenario_torpedo_plot ONCE with:\n"
f"- old_other_income: {oi}\n"
f"- new_other_incomes: [{incomes_str}]\n"
f"- scenario_labels: [{labels_str}]\n\n"
f"This will show BOTH options on the same plot. "
f"Then explain the dollar trade-offs for each option."
)
else: # Same-Old Zone
return (
f"I'm past the torpedo zone (Same-Old Zone) with other income of ${oi:,.0f}. "
"What optimization strategies make sense at my income level? "
"Show me where I am on the torpedo curve."
)
def _render_scenario_panel(old_kn: dict, new_kn: dict):
    """Show a single before/after "Scenario Comparison" card via st.markdown.

    Each metric that is non-None in both dicts becomes a row reading
    "old → new". Numeric metrics with a direction preference also get an
    arrow and the absolute change, green when it is an improvement and red
    otherwise. Tax-zone values are shown as coloured badges.

    Args:
        old_kn: Key numbers for the current situation.
        new_kn: Key numbers for the scenario being compared.
    """
    # (label, key, display format, lower_is_better); None = no good/bad colouring.
    metrics = (
        ("Tax Owed", "tax_owed", "${:,.0f}", True),
        ("Marginal Rate", "marginal_rate", "{:.1f}%", True),
        ("Effective Rate", "effective_rate", "{:.1f}%", True),
        ("Taxable SSB", "taxable_ssb", "${:,.0f}", True),
        ("Take-Home", "take_home", "${:,.0f}", False),
        ("Tax Zone", "zone", "{}", None),
        ("Ref. Taxable Inc.", "taxable_income", "${:,.0f}", None),
    )
    row_start = '<div style="padding:6px 0;border-bottom:1px solid #e0e0e0;">'
    caption = (
        '<div style="font-size:11px;color:#666;text-transform:uppercase;'
        'letter-spacing:0.5px;">{}</div>'
    )
    transition = '<div style="font-size:14px;font-family:inherit;">{} &rarr; {}</div>'

    body_parts = []
    for title, key, pattern, lower_is_better in metrics:
        before = old_kn.get(key)
        after = new_kn.get(key)
        if before is None or after is None:
            continue

        if not (isinstance(before, (int, float)) and isinstance(after, (int, float))):
            if key == "zone":
                before_txt = _zone_badge_html(str(before), font_size=12)
                after_txt = _zone_badge_html(str(after), font_size=12)
            else:
                before_txt, after_txt = str(before), str(after)
            body_parts.append(
                row_start + caption.format(title)
                + transition.format(before_txt, after_txt) + "</div>"
            )
            continue

        change = after - before
        before_txt = pattern.format(before)
        after_txt = pattern.format(after)
        change_badge = ""
        if lower_is_better is not None and change != 0:
            better = change < 0 if lower_is_better else change > 0
            tone = "#2e7d32" if better else "#c62828"
            arrow = "\u2193" if change < 0 else "\u2191"
            change_badge = (
                f'<span style="color:{tone};font-weight:bold;">'
                f"{arrow} {pattern.format(abs(change))}</span>"
            )
        body_parts.append(
            row_start + caption.format(title)
            + transition.format(before_txt, after_txt) + change_badge + "</div>"
        )

    st.markdown(
        '<div style="background:#fafafa;border:1px solid #e0e0e0;border-radius:10px;padding:12px 16px;">'
        '<div style="font-size:15px;font-weight:700;color:#1a237e;margin-bottom:8px;">Scenario Comparison</div>'
        + "".join(body_parts)
        + "</div>",
        unsafe_allow_html=True,
    )
def _render_multi_scenario_panel(old_kn: dict, all_new_kns: list):
    """Render one before/after comparison card per scenario, side by side.

    Each scenario gets its own column. The card border and title use the
    scenario's ``scenario_color`` (default ``#1a237e``). The title text is
    its ``scenario_label`` (default "Scenario N"). Rows show "old → new" for
    each metric non-None in both dicts; numeric metrics with a direction
    preference also show a coloured arrow and the absolute change. Zone
    values are shown as badges.

    The label and colour come from scenario tool results stored by
    ``_send_button_message``, and the tool arguments may be chosen by the
    LLM. They are therefore HTML-escaped before being embedded in markup
    rendered with ``unsafe_allow_html=True``.

    Args:
        old_kn: Key numbers for the current situation.
        all_new_kns: One key-numbers dict per scenario. An empty list renders
            nothing (``st.columns`` requires at least one column).
    """
    import html as html_mod

    if not all_new_kns:
        return

    fields = [
        ("Tax Owed", "tax_owed", "${:,.0f}", True),
        ("Marginal Rate", "marginal_rate", "{:.1f}%", True),
        ("Effective Rate", "effective_rate", "{:.1f}%", True),
        ("Taxable SSB", "taxable_ssb", "${:,.0f}", True),
        ("Take-Home", "take_home", "${:,.0f}", False),
        ("Tax Zone", "zone", "{}", None),
    ]
    cols = st.columns(len(all_new_kns))
    for col_idx, new_kn in enumerate(all_new_kns):
        color = html_mod.escape(str(new_kn.get("scenario_color", "#1a237e")))
        label = html_mod.escape(
            str(new_kn.get("scenario_label", f"Scenario {col_idx + 1}"))
        )
        rows_html = []
        for field_label, key, fmt, lower_is_better in fields:
            old_val = old_kn.get(key)
            new_val = new_kn.get(key)
            if old_val is None or new_val is None:
                continue
            if isinstance(old_val, (int, float)) and isinstance(new_val, (int, float)):
                delta = new_val - old_val
                old_str = fmt.format(old_val)
                new_str = fmt.format(new_val)
                if lower_is_better is not None and delta != 0:
                    is_good = (delta < 0) if lower_is_better else (delta > 0)
                    d_color = "#2e7d32" if is_good else "#c62828"
                    arrow = "\u2193" if delta < 0 else "\u2191"
                    delta_str = fmt.format(abs(delta))
                    delta_html = f'<span style="color:{d_color};font-weight:bold;">{arrow} {delta_str}</span>'
                else:
                    delta_html = ""
                rows_html.append(
                    f'<div style="padding:5px 0;border-bottom:1px solid #e0e0e0;">'
                    f'<div style="font-size:10px;color:#666;text-transform:uppercase;letter-spacing:0.5px;">{field_label}</div>'
                    f'<div style="font-size:13px;">{old_str} &rarr; {new_str}</div>'
                    f'{delta_html}</div>'
                )
            else:
                # Use colored badges for zone names
                if key == "zone":
                    old_display = _zone_badge_html(str(old_val), font_size=11)
                    new_display = _zone_badge_html(str(new_val), font_size=11)
                else:
                    old_display = str(old_val)
                    new_display = str(new_val)
                rows_html.append(
                    f'<div style="padding:5px 0;border-bottom:1px solid #e0e0e0;">'
                    f'<div style="font-size:10px;color:#666;text-transform:uppercase;letter-spacing:0.5px;">{field_label}</div>'
                    f'<div style="font-size:13px;">{old_display} &rarr; {new_display}</div>'
                    f'</div>'
                )
        # Named card_html (not "html") so it cannot shadow the html module.
        card_html = (
            f'<div style="background:#fafafa;border:2px solid {color};border-radius:10px;padding:10px 12px;">'
            f'<div style="font-size:13px;font-weight:700;color:{color};margin-bottom:6px;">{label}</div>'
            + "".join(rows_html) +
            '</div>'
        )
        with cols[col_idx]:
            st.markdown(card_html, unsafe_allow_html=True)
def _render_key_numbers_panel(tax_result: dict):
    """Show the "Your Key Numbers" card in the right-hand panel.

    Each listed metric that is present (non-None) in ``tax_result`` becomes
    one row. The tax zone is shown as a coloured badge. Other values use
    their display format, falling back to ``str(value)`` when formatting
    fails.

    Args:
        tax_result: Result dict from ``calculate_tax_situation``.
    """
    metrics = (
        ("Tax Zone", "tax_zone", "{}"),
        ("Tax Owed", "tax_owed", "${:,.0f}"),
        ("Marginal Rate", "marginal_rate", "{:.1f}%"),
        ("Effective Rate", "effective_rate", "{:.1f}%"),
        ("Taxable SSB", "taxable_ssb", "${:,.0f}"),
        ("Take-Home", "take_home", "${:,.0f}"),
        ("Gross Income", "gross_income", "${:,.0f}"),
        ("Zero Point", "zero_point", "${:,.0f}"),
        ("Confluence Pt.", "confluence_point", "${:,.0f}"),
    )
    pieces = []
    for title, key, pattern in metrics:
        value = tax_result.get(key)
        if value is None:
            continue
        if key == "tax_zone":
            shown = _zone_badge_html(str(value), font_size=14)
        else:
            try:
                shown = pattern.format(value)
            except (ValueError, TypeError):
                shown = str(value)
        pieces.append(
            '<div style="padding:6px 0;border-bottom:1px solid #e8eaf6;">'
            f'<div style="font-size:11px;color:#666;text-transform:uppercase;letter-spacing:0.5px;">{title}</div>'
            f'<div style="font-size:15px;font-weight:600;color:#1a237e;font-family:inherit;">{shown}</div>'
            "</div>"
        )
    st.markdown(
        '<div style="background:#f0f4ff;border:1px solid #c5cae9;border-radius:10px;padding:12px 16px;">'
        '<div style="font-size:15px;font-weight:700;color:#1a237e;margin-bottom:8px;">Your Key Numbers</div>'
        + "".join(pieces)
        + "</div>",
        unsafe_allow_html=True,
    )
def _send_button_message(msg: str):
    """Send a quick-action prompt to the LLM and record the outcome in session state.

    Does nothing when the orchestrator reports it is unavailable. For every
    ``generate_scenario_torpedo_plot`` call made during the exchange, it
    stores the old, new and all-new key numbers. When the returned image file
    exists, that image replaces the main plot. The summary text is replaced
    only when the reply is non-blank, so a tools-only reply keeps the
    previous summary visible.

    Args:
        msg: Prompt text to send.
    """
    orch: ChatOrchestrator = st.session_state.orchestrator
    if not orch.is_available():
        return
    reply, _images = orch.send_message(msg)

    scenario_results = (
        call["result"]
        for call in orch.last_tool_results
        if call["tool"] == "generate_scenario_torpedo_plot"
    )
    for result in scenario_results:
        st.session_state.scenario_old = result.get("old_key_numbers")
        st.session_state.scenario_new = result.get("new_key_numbers")
        st.session_state.all_new_key_numbers = result.get("all_new_key_numbers")
        image = result.get("image_path", "")
        if image and os.path.exists(image):
            st.session_state.plot_path = image

    if reply.strip():
        st.session_state.llm_summary = _escape_dollars(reply)
# ---------------------------------------------------------------------------
# Session state initialization
# ---------------------------------------------------------------------------
def _init_session_state():
    """Seed ``st.session_state`` with defaults for any keys not yet present.

    Runs on every script rerun and never overwrites existing values. The
    ``ChatOrchestrator`` is constructed only when the "orchestrator" key is
    missing. Previously it was built as an eager dict default on every rerun
    and discarded whenever the key already existed.
    """
    if "orchestrator" not in st.session_state:
        st.session_state["orchestrator"] = ChatOrchestrator()

    defaults = {
        "profile": {},
        "tax_result": None,
        "plot_path": None,
        "llm_summary": None,
        "analyzed": False,
        "filing_status_display": None,
        "show_whatif_input": False,
        "scenario_old": None,
        "scenario_new": None,
        "all_new_key_numbers": None,
        "llm_loading": False,
    }
    for key, val in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = val
# ---------------------------------------------------------------------------
# Main app
# ---------------------------------------------------------------------------
def main():
    """Render the Tax Torpedo Analyzer page for one Streamlit script run.

    Each run performs these steps:

    * Apply page config and CSS, seed session-state defaults, draw the header
      and the sidebar inputs.
    * On "Analyze My Situation", compute the tax result and plot, then reset
      all LLM and scenario state.
    * Once analysed, the left column shows the plot and summary (LLM or
      fallback), three quick-action buttons, an optional what-if form and a
      collapsible calculation breakdown. The right column shows either the
      key-numbers panel or a scenario comparison.
    * Before any analysis, a landing page is shown instead.
    * The disclaimer footer is rendered unless the tax calculation reports an
      error.

    The quick-action buttons store their prompt in ``_pending_msg`` and set
    ``llm_loading``. The following run sends it under a spinner, always
    clears the request (even if the LLM call raises), and reruns.
    """
    st.set_page_config(
        page_title="Tax Torpedo Analyzer",
        page_icon="\U0001f4ca",
        layout="wide",
    )
    # Inject custom CSS
    st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
    _init_session_state()

    # --- Header ---
    st.markdown(
        '<h1 style="text-align:center; color:#1a237e; margin-bottom:4px;">'
        'Tax Torpedo Analyzer</h1>'
        '<p style="text-align:center; color:#666; font-size:16px; margin-top:0;">'
        'Understand how your Social Security benefits are taxed</p>',
        unsafe_allow_html=True,
    )

    # --- Sidebar: Inputs ---
    with st.sidebar:
        st.header("Your Information")
        filing_status_display = st.selectbox(
            "Filing Status",
            options=STATUS_CHOICES,
            index=1,  # default: MFJ
        )
        other_income = st.number_input(
            "Total Other Taxable Income ($)",
            min_value=0,
            max_value=2_000_000,
            value=40_000,
            step=1_000,
        )
        ssb = st.number_input(
            "Annual Social Security Benefit ($)",
            min_value=0,
            max_value=200_000,
            value=31_810,
            step=100,
        )
        analyze_clicked = st.button(
            "Analyze My Situation",
            type="primary",
            use_container_width=True,
        )

    # --- Handle Analyze click ---
    if analyze_clicked:
        filing_status = STATUS_MAP.get(filing_status_display, "MFJ")
        # Run tax calculation
        tax_result = calculate_tax_situation(filing_status, float(ssb), float(other_income))
        if "error" in tax_result:
            st.error(f"Error: {tax_result['error']}")
            return
        # Generate torpedo plot
        plot_result = generate_torpedo_plot(filing_status, float(ssb), float(other_income))
        # Build profile
        profile = UserProfile(
            filing_status=filing_status,
            ssb_annual=float(ssb),
            other_income=float(other_income),
        )
        profile.tax_zone = tax_result.get("tax_zone")
        profile.zero_point = tax_result.get("zero_point")
        profile.confluence_point = tax_result.get("confluence_point")
        profile.zero_point_oi = tax_result.get("zero_point_oi")
        profile.confluence_point_oi = tax_result.get("confluence_point_oi")
        profile.ref_taxable_income = tax_result.get("ref_taxable_income")
        # Store in session state
        st.session_state.tax_result = tax_result
        st.session_state.plot_path = plot_result.get("image_path", "")
        st.session_state.profile = profile.to_dict()
        st.session_state.filing_status_display = filing_status_display
        st.session_state.analyzed = True
        st.session_state.llm_summary = None  # Reset so we re-fetch
        st.session_state.show_whatif_input = False
        st.session_state.scenario_old = None
        st.session_state.scenario_new = None
        st.session_state.all_new_key_numbers = None
        st.session_state.llm_loading = False

    # --- Display results (if analyzed) ---
    if st.session_state.analyzed and st.session_state.tax_result:
        tax_result = st.session_state.tax_result
        # Two-column layout: main content (left) + key numbers panel (right)
        left_col, right_col = st.columns([3, 1])

        with right_col:
            all_new = st.session_state.get("all_new_key_numbers")
            if all_new and len(all_new) > 1 and st.session_state.scenario_old:
                _render_multi_scenario_panel(st.session_state.scenario_old, all_new)
            elif st.session_state.scenario_old and st.session_state.scenario_new:
                _render_scenario_panel(
                    st.session_state.scenario_old,
                    st.session_state.scenario_new,
                )
            else:
                _render_key_numbers_panel(tax_result)

        with left_col:
            # --- Plot + Analysis area (replaced with spinner while loading) ---
            if st.session_state.llm_loading:
                with st.spinner("Analyzing your situation..."):
                    pending = st.session_state.get("_pending_msg", "")
                    try:
                        if pending:
                            _send_button_message(pending)
                    finally:
                        # Clear the request even if the LLM call raised, so a
                        # failure cannot re-trigger the same call on every rerun.
                        st.session_state.pop("_pending_msg", None)
                        st.session_state.llm_loading = False
                    st.rerun()
            else:
                # Torpedo plot
                plot_path = st.session_state.plot_path
                if plot_path and os.path.exists(plot_path):
                    st.image(plot_path, use_container_width=True)

                # LLM Summary text (directly under the plot)
                if st.session_state.llm_summary is None:
                    orch: ChatOrchestrator = st.session_state.orchestrator
                    if orch.is_available():
                        with st.spinner("Generating AI summary..."):
                            try:
                                profile = UserProfile.from_dict(st.session_state.profile)
                                orch.start_chat(profile.summary_text())
                                initial_msg = _build_initial_llm_message(
                                    st.session_state.filing_status_display,
                                    float(tax_result.get("ssb_annual", 0)),
                                    float(tax_result.get("other_income", 0)),
                                    tax_result,
                                )
                                summary, _ = orch.send_message(initial_msg)
                            except Exception as exc:  # UI boundary around the LLM call
                                # Fall back instead of leaving llm_summary as None,
                                # which would retry (and fail) on every rerun.
                                st.session_state.llm_summary = _fallback_summary(
                                    tax_result, str(exc) or type(exc).__name__
                                )
                            else:
                                if summary.strip():
                                    st.session_state.llm_summary = _escape_dollars(summary)
                                else:
                                    # LLM returned empty — use fallback
                                    st.session_state.llm_summary = _fallback_summary(tax_result)
                    else:
                        st.session_state.llm_summary = _fallback_summary(
                            tax_result, orch.get_error() or ""
                        )
                st.markdown(
                    f'<div class="analysis-text">{_md_to_html_simple(st.session_state.llm_summary)}</div>',
                    unsafe_allow_html=True,
                )

            # --- Action Buttons ---
            st.markdown("---")
            btn_cols = st.columns(3)
            with btn_cols[0]:
                understand_clicked = st.button(
                    "Help me understand this plot",
                    key="btn_understand",
                    use_container_width=True,
                )
            with btn_cols[1]:
                improve_clicked = st.button(
                    "Improve my tax situation",
                    key="btn_improve",
                    use_container_width=True,
                )
            with btn_cols[2]:
                whatif_clicked = st.button(
                    'Explore "what if" scenarios',
                    key="btn_whatif",
                    use_container_width=True,
                )

            # Handle button clicks — store message and trigger loading on rerun
            if understand_clicked:
                msg = (
                    "Please explain the torpedo chart shown above in detail. "
                    "Describe what the two lines (black dashed baseline and red total tax) "
                    "mean, and how they CONVERGE at the confluence point. "
                    "Explain that in the torpedo zone the red line rises steeply to catch "
                    "up with the baseline because each extra dollar of income also triggers "
                    "more SSB to become taxable. Explain what the three shaded zones represent "
                    "and where my red star marker is. "
                    "Explain what this means for my specific tax situation using my actual numbers. "
                    "Also explain what the bottom panel (marginal rate) shows."
                )
                st.session_state._pending_msg = msg
                st.session_state.llm_loading = True
                st.rerun()
            if improve_clicked:
                msg = _build_improve_message(tax_result)
                st.session_state._pending_msg = msg
                st.session_state.llm_loading = True
                st.rerun()
            if whatif_clicked:
                st.session_state.show_whatif_input = True
                st.rerun()

            # Show what-if input if toggled
            if st.session_state.show_whatif_input:
                with st.form("whatif_form", clear_on_submit=True):
                    whatif_text = st.text_input(
                        "Describe your scenario:",
                        placeholder="e.g., What if I take $10,000 more from my IRA?",
                    )
                    submitted = st.form_submit_button("Run Scenario")
                if submitted and whatif_text:
                    st.session_state._pending_msg = whatif_text
                    st.session_state.llm_loading = True
                    st.session_state.show_whatif_input = False
                    st.rerun()

            # Equations breakdown (collapsed, below buttons)
            with st.expander("Detailed Calculation Breakdown", expanded=False):
                st.markdown(_format_equations_markdown(tax_result))
    else:
        # --- Welcome / Landing Page ---
        st.markdown("<br>", unsafe_allow_html=True)
        # Center the logo using columns
        _spacer_l, logo_col, _spacer_r = st.columns([1, 1, 1])
        with logo_col:
            st.image("Logo.png", use_container_width=True)
        # Caption
        st.markdown(
            '<p style="text-align:center; color:#555; font-size:18px; '
            'max-width:600px; margin:12px auto 0 auto; line-height:1.6;">'
            'Let me help protect you from the '
            '<b style="color:#c62828;">tax torpedo</b>.<br>'
            'Enter your information in the sidebar and click '
            '<b>Analyze My Situation</b> to see how your '
            'Social Security benefits are really being taxed.'
            '</p>',
            unsafe_allow_html=True,
        )

    # --- Disclaimer ---
    st.markdown(
        '<div class="disclaimer">'
        "Educational tool only. Uses 2016 federal tax rates. "
        "Does not include state taxes. Consult a qualified tax professional."
        "</div>",
        unsafe_allow_html=True,
    )
def _md_to_html_simple(md_text: str) -> str:
"""Minimal markdown-to-HTML for the analysis text.
Converts **bold**, *italic*, and newlines to HTML so we can render
inside a styled <div> without Streamlit's default markdown quirks.
"""
if not md_text:
return ""
import html as html_mod
text = html_mod.escape(md_text)
# Bold: **text**
text = re.sub(r'\*\*(.+?)\*\*', r'<strong>\1</strong>', text)
# Italic: *text*
text = re.sub(r'\*(.+?)\*', r'<em>\1</em>', text)
# Double newlines -> paragraph breaks
text = text.replace('\n\n', '</p><p>')
# Single newlines -> line breaks
text = text.replace('\n', '<br>')
return f'<p>{text}</p>'
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
# Streamlit re-executes this module on every interaction. Render the app only
# once the password gate lets the user through (it always does when
# APP_PASSWORD is unset).
_app_unlocked = check_password()
if _app_unlocked:
    main()