# Bobs_Tax_Project / plot_generators.py
# rmd826's picture
# Upload 9 files
# a54f28a verified
# -*- coding: utf-8 -*-
"""
Plot generators for the Tax Torpedo Analyzer.
All functions save figures to PNG files and return the file path.
No plt.show() calls -- designed for headless use.
Uses the analyst's "reference taxable income" x-axis convention:
x_plot = OI - Std. Ded. + 0.85 * SSB
Elderly-friendly styling: large fonts, high contrast, clear annotations.
"""
from __future__ import annotations
import os
import tempfile
from typing import Dict, List, Optional, Tuple
import numpy as np
import matplotlib
matplotlib.use("Agg") # headless backend
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from tax_engine import (
CONFIGS, TaxConfig,
ssb_tax, bracket_tax, compute_baseline_tax, tax_with_ssb, tax_with_ssb_detail,
bracket_marginal_rate, total_marginal_rate, find_torpedo_bounds,
classify_zone,
)
# ---------------------------------------------------------------------------
# Global plot styling (elderly-friendly)
# ---------------------------------------------------------------------------
PLOT_STYLE = {
"font.size": 14,
"axes.titlesize": 18,
"axes.labelsize": 16,
"xtick.labelsize": 13,
"ytick.labelsize": 13,
"legend.fontsize": 12,
"figure.dpi": 200,
"figure.facecolor": "white",
"axes.facecolor": "white",
"axes.grid": True,
"grid.alpha": 0.3,
}
ZONE_COLORS = {
"No-Tax Zone": ("#c8e6c9", "green"), # light green bg, green text
"High-Tax Zone": ("#ffcdd2", "#c62828"), # light red bg, red text
"Same-Old Zone": ("#bbdefb", "#1565c0"), # light blue bg, blue text
}
# Darker zone label colors for text annotations above plot
_ZONE_LABEL_COLORS = {
"no_tax": "#1b5e20", # dark green
"high_tax": "#b71c1c", # dark red
"same_old": "#4a148c", # dark purple
}
X_AXIS_LABEL = "Reference Income: Other Income + 85% of SSB ($)"
# Colors for multiple scenario positions
_SCENARIO_COLORS = ["#1565c0", "#ff6f00", "#2e7d32", "#6a1b9a"]
_SCENARIO_LABELS = ["Scenario A", "Scenario B", "Scenario C", "Scenario D"]
def _save_fig(fig, prefix: str = "plot") -> str:
    """Save *fig* to a new temporary PNG file and return that file's path.

    The figure is always closed, even if saving fails, so repeated calls in a
    long-running headless process do not accumulate open figures.  If saving
    fails, the (empty or partial) temp file is removed before the exception
    propagates.

    Args:
        fig: Matplotlib figure to save.
        prefix: Tag embedded in the temp file name (``tax_<prefix>_*.png``).

    Returns:
        Absolute path of the written PNG; the caller owns the file.
    """
    fd, path = tempfile.mkstemp(suffix=".png", prefix=f"tax_{prefix}_")
    os.close(fd)  # savefig reopens by path; we only need the unique name
    try:
        fig.savefig(path, dpi=300, bbox_inches="tight", facecolor="white")
    except Exception:
        # Best-effort cleanup of the useless temp file; the original error
        # is what the caller needs to see, so it is re-raised below.
        try:
            os.remove(path)
        except OSError:
            pass
        raise
    finally:
        plt.close(fig)
    return path
def _dollar_fmt(x, _=None):
"""Format axis ticks as $XX,XXX."""
return f"${x:,.0f}"
def _pct_fmt(x, _=None):
"""Format axis ticks as XX%."""
return f"{x:.0f}%"
def _add_zone_shading(ax, x_plot, ts_plot, te_plot):
"""Add zone shading to an axis WITHOUT legend labels."""
if ts_plot is not None:
ax.axvspan(x_plot[0], ts_plot, color="green", alpha=0.08)
if ts_plot is not None and te_plot is not None:
ax.axvspan(ts_plot, te_plot, color="red", alpha=0.08)
if te_plot is not None:
ax.axvspan(te_plot, x_plot[-1], color="purple", alpha=0.06)
def _add_zone_text_labels(ax, x_plot, ts_plot, te_plot, headroom: float = 0.10):
    """Add zone text labels above the plotted data in darker zone colors.

    Each label is centred horizontally over its zone (same bounds as the
    zone shading) and anchored, bottom-aligned, at the current top y-limit.
    The y-range is then extended upward so the labels fit inside the axes.

    Args:
        ax: Matplotlib axes to annotate.
        x_plot: Plotted x values; only the first and last entries are used.
        ts_plot: Zero point on the analyst x-axis, or None if unknown.
        te_plot: Confluence point on the analyst x-axis, or None if unknown.
        headroom: Fraction of the current y-span added above the top limit
            (default 10%).
    """
    y_lo, y_hi = ax.get_ylim()
    label_y = y_hi  # at the top of the visible area

    zones = []
    if ts_plot is not None:
        # Centre the No-Tax label on the non-negative part of the band.
        zones.append(((max(x_plot[0], 0) + ts_plot) / 2, "No-Tax Zone", "no_tax"))
    if ts_plot is not None and te_plot is not None:
        zones.append(((ts_plot + te_plot) / 2, "High-Tax Zone", "high_tax"))
    if te_plot is not None:
        zones.append(((te_plot + x_plot[-1]) / 2, "Same-Old Zone", "same_old"))

    for mid, text, key in zones:
        ax.text(mid, label_y, text, ha="center", va="bottom",
                fontsize=12, fontweight="bold", color=_ZONE_LABEL_COLORS[key],
                clip_on=False)

    # Pad by a fraction of the span (not of the top value) so the room for
    # the labels is sensible even when the lower limit is non-zero or the top
    # limit is <= 0.  With a zero lower limit this equals top * (1 + headroom).
    span = y_hi - y_lo
    if span <= 0:
        span = abs(y_hi) or 1.0
    ax.set_ylim(y_lo, y_hi + headroom * span)
def _add_legend_below(fig, axes, extra_handles=None, extra_labels=None, ncol=None):
"""Collect legend handles from all axes and place a single row below the charts.
*extra_handles* / *extra_labels* are appended to the collected items.
"""
handles, labels = [], []
seen = set()
for ax in (axes if hasattr(axes, '__iter__') else [axes]):
for h, l in zip(*ax.get_legend_handles_labels()):
if l not in seen:
handles.append(h)
labels.append(l)
seen.add(l)
# Remove any per-axis legend
legend = ax.get_legend()
if legend:
legend.remove()
if extra_handles and extra_labels:
for h, l in zip(extra_handles, extra_labels):
if l not in seen:
handles.append(h)
labels.append(l)
seen.add(l)
if not handles:
return
if ncol is None:
ncol = min(5, len(handles))
fig.legend(
handles, labels,
loc="lower center",
ncol=min(ncol, len(handles)),
fontsize=11,
frameon=True,
fancybox=True,
shadow=False,
borderpad=0.6,
columnspacing=1.5,
)
# Make room at the bottom for the legend (extra space for two rows)
fig.subplots_adjust(bottom=0.16)
def _compute_key_numbers(other_income, ssb, cfg, ts_plot, te_plot,
                         torpedo_start, torpedo_end, delta=100.0):
    """Summarise one income position as a dict of display-ready numbers.

    Args:
        other_income: Other (non-SSB) income of the position.
        ssb: Annual Social Security benefit.
        cfg: Tax configuration for the filing status.
        ts_plot, te_plot: Zero / confluence points on the analyst x-axis
            (or None).
        torpedo_start, torpedo_end: The same points as other-income values
            (or None).
        delta: Forwarded to ``total_marginal_rate``.

    Returns:
        Dict, in this key order: tax_owed, taxable_ssb, marginal_rate (%),
        effective_rate (% of other income; 0 when other income <= 0),
        zero_point, confluence_point, zero_point_oi, confluence_point_oi,
        zone, gross_income, take_home, taxable_income (reference income
        OI - std. ded. + 0.85*SSB, floored at 0), other_income,
        filing_status, ssb.
    """
    def _round_or_none(value):
        return round(value, 0) if value is not None else None

    tax = tax_with_ssb(other_income, ssb, cfg)
    gross = other_income + ssb
    take_home = gross - tax
    marginal_pct = 100.0 * total_marginal_rate(other_income, ssb, cfg, delta=delta)
    taxable_ssb = ssb_tax(other_income, ssb, cfg)
    reference_x = other_income - cfg.standard_deduction + 0.85 * ssb
    if other_income > 0:
        effective_pct = 100.0 * tax / other_income
    else:
        effective_pct = 0.0
    zone_name = classify_zone(other_income, ssb, cfg, torpedo_start, torpedo_end)

    return {
        "tax_owed": round(tax, 2),
        "taxable_ssb": round(taxable_ssb, 2),
        "marginal_rate": round(marginal_pct, 2),
        "effective_rate": round(effective_pct, 2),
        "zero_point": _round_or_none(ts_plot),
        "confluence_point": _round_or_none(te_plot),
        "zero_point_oi": _round_or_none(torpedo_start),
        "confluence_point_oi": _round_or_none(torpedo_end),
        "zone": zone_name,
        "gross_income": round(gross, 2),
        "take_home": round(take_home, 2),
        "taxable_income": round(max(0.0, reference_x), 2),
        "other_income": other_income,
        "filing_status": cfg.name,
        "ssb": ssb,
    }
def _knee_sensitivity_lines(
    ssb: float, cfg: "TaxConfig", x_max: float, ssb_step: float = 5000.0
):
    """
    Trace the locus of the two SSB knee points as SSB varies.

    For each threshold t, SSB is stepped upward from $0 in increments of
    *ssb_step*.  At each SSB level the other income (OI) that puts
    provisional income exactly on the threshold is ``OI = t - 0.5*SSB``;
    that knee is placed on the analyst x-axis
    ``x_plot = OI - std_ded + 0.85*SSB`` and its total tax and marginal rate
    are recorded whenever ``0 <= x_plot <= x_max``.

    Tracing of a threshold stops at the first recorded point whose total tax
    is zero (the knee has entered the no-tax zone), once x_plot passes
    *x_max* (x_plot rises by 0.35 per SSB dollar, so no later point can be
    in range), once the knee would need negative OI, or once SSB exceeds
    ``ssb + 1_000_000``.

    Knee 1: provisional income = t1 (0% -> 50% taxable SSB)
    Knee 2: provisional income = t2 (50% -> 85% taxable SSB)

    Raises
    ------
    ValueError
        If *ssb_step* is not positive (the trace would never advance).

    Returns
    -------
    k1_x, k1_y_tax, k1_y_mr – knee-1 x positions, total-tax y, marginal-rate y (%)
    k2_x, k2_y_tax, k2_y_mr – same for knee 2
    """
    if ssb_step <= 0:
        raise ValueError(f"ssb_step must be positive, got {ssb_step!r}")
    t1, t2 = cfg.ssb_thresholds.t1, cfg.ssb_thresholds.t2
    ssb_cap = ssb + 1_000_000  # safety cap on the SSB sweep

    def _trace(t_thresh):
        xs, ys_tax, ys_mr = [], [], []
        ssb_k = 0
        while ssb_k <= ssb_cap:
            oi = t_thresh - 0.5 * ssb_k
            if oi < 0:
                # The knee would need negative other income here and at every
                # larger SSB; clipping OI to 0 would plot the point at the
                # wrong x, so stop instead.
                break
            xp = oi - cfg.standard_deduction + 0.85 * ssb_k
            if xp > x_max:
                # xp only grows with SSB: every later point is off-plot too.
                break
            if xp >= 0.0:
                y_tax = tax_with_ssb(oi, ssb_k, cfg)
                y_mr = 100.0 * total_marginal_rate(oi, ssb_k, cfg)
                xs.append(xp)
                ys_tax.append(y_tax)
                ys_mr.append(y_mr)
                if y_tax <= 0:
                    break  # knee has reached the no-tax zone – stop
            ssb_k += ssb_step
        return xs, ys_tax, ys_mr

    k1_x, k1_yt, k1_ym = _trace(t1)
    k2_x, k2_yt, k2_ym = _trace(t2)
    return k1_x, k1_yt, k1_ym, k2_x, k2_yt, k2_ym
# ---------------------------------------------------------------------------
# Plot 1: Torpedo Overview (2-panel: total tax + marginal rate)
# Uses analyst x-axis: OI - Std. Ded. + 0.85*SSB
# ---------------------------------------------------------------------------
def generate_torpedo_plot(
    filing_status: str,
    ssb: float,
    other_income: float,
    x_max: Optional[float] = None,
    n: int = 800,
    delta: float = 100.0,
) -> Dict:
    """
    Main torpedo visualization. 2-panel figure:
      Top:    Total tax owed vs reference taxable income
      Bottom: Marginal rate vs reference taxable income

    X-axis: OI - Std. Ded. + 0.85*SSB ("reference taxable income")
      Baseline (black dashed): bracket_tax(x_plot) -- brackets alone
      Total (red solid): actual IRS tax with SSB torpedo

    Args:
        filing_status: Key into ``CONFIGS``.
        ssb: Annual Social Security benefit ($).
        other_income: The user's other (non-SSB) income ($).
        x_max: Largest other-income value sampled on the curves; defaults to
            ``max(1.5 * other_income, 100000)``.  Also forwarded to
            ``find_torpedo_bounds`` and ``_knee_sensitivity_lines``.
        n: Number of sample points along the curves.
        delta: Forwarded to ``total_marginal_rate`` as its ``delta``.

    Returns:
        Dict with 'image_path' (temp PNG) and 'key_numbers'
        (see ``_compute_key_numbers``).

    Raises:
        KeyError: if *filing_status* is not a key of ``CONFIGS``.
    """
    cfg = CONFIGS[filing_status]
    if x_max is None:
        x_max = max(other_income * 1.5, 100000)

    def _to_plot(oi):
        # Map other income (scalar or array) onto the analyst x-axis.
        return oi - cfg.standard_deduction + 0.85 * ssb

    with plt.rc_context(PLOT_STYLE):
        # --- Analyst x-axis convention: start the OI grid where x_plot == 0 ---
        x_start = cfg.standard_deduction - 0.85 * ssb
        x = np.linspace(x_start, x_max, n)
        x_plot = _to_plot(x)
        x_clipped = np.maximum(0.0, x)  # OI can't be negative

        # Total curve: actual OI (clipped) drives tax calculations
        tax_total = np.array([tax_with_ssb(xi, ssb, cfg) for xi in x_clipped], dtype=float)
        mr_total = np.array(
            [100.0 * total_marginal_rate(xi, ssb, cfg, delta=delta) for xi in x_clipped],
            dtype=float,
        )
        taxable_ssb_arr = np.array([ssb_tax(xi, ssb, cfg) for xi in x_clipped], dtype=float)
        # Baseline curve: bracket_tax(oa) -- "what would brackets alone give?"
        tax_base = np.array([bracket_tax(max(0.0, oa), cfg) for oa in x_plot], dtype=float)
        mr_base = np.array(
            [100.0 * bracket_marginal_rate(oa + cfg.standard_deduction, cfg) for oa in x_plot],
            dtype=float,
        )

        # User's point on the analyst axis (the remaining user numbers are
        # produced by _compute_key_numbers below).
        my_tax = tax_with_ssb(other_income, ssb, cfg)
        my_marginal = 100.0 * total_marginal_rate(other_income, ssb, cfg, delta=delta)
        my_x_plot = _to_plot(other_income)

        # Zone boundaries (raw OI values) and their analyst-axis positions
        torpedo_start, torpedo_end = find_torpedo_bounds(cfg, ssb, x_max)
        ts_plot = _to_plot(torpedo_start) if torpedo_start is not None else None
        te_plot = _to_plot(torpedo_end) if torpedo_end is not None else None

        # Knee sensitivity lines (green): locus of knee points as SSB varies
        k1_x, k1_yt, _k1_ym, k2_x, k2_yt, _k2_ym = _knee_sensitivity_lines(ssb, cfg, x_max)

        key_numbers = _compute_key_numbers(
            other_income, ssb, cfg, ts_plot, te_plot,
            torpedo_start, torpedo_end, delta,
        )

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10))

        # === TOP PANEL: Total Taxes Owed vs Analyst X-Axis ===
        _add_zone_shading(ax1, x_plot, ts_plot, te_plot)
        ax1.plot(x_plot, tax_base, color="black", linewidth=2, linestyle="--",
                 label="Baseline Tax (no SSB)")
        ax1.plot(x_plot, tax_total, color="#e53935", linewidth=2,
                 label="Total Tax (with SSB)")
        ax1.scatter(my_x_plot, my_tax, marker="*", s=500, color="red",
                    edgecolors="white", zorder=4, label="Your Tax Owed")
        # Zone boundary markers: evaluate the tax exactly at each boundary
        # instead of at the nearest grid sample so markers sit on the curve.
        if ts_plot is not None:
            tax_at_zp = tax_with_ssb(max(0.0, torpedo_start), ssb, cfg)
            ax1.scatter(ts_plot, tax_at_zp, marker="o", color="green",
                        s=120, zorder=3, label="Zero Point")
        if te_plot is not None:
            tax_at_cp = tax_with_ssb(max(0.0, torpedo_end), ssb, cfg)
            ax1.scatter(te_plot, tax_at_cp, marker="D", color="orange",
                        s=100, zorder=3, label="Confluence Point (85% cap)")
        # Green knee-locus lines: how the knee point moves as SSB changes
        if len(k1_x) > 1:
            ax1.plot(k1_x, k1_yt, color="green", linewidth=1.8, zorder=5,
                     linestyle="--", label="Knee locus: 0%\u219250% taxable SSB")
        if len(k2_x) > 1:
            ax1.plot(k2_x, k2_yt, color="green", linewidth=1.8, zorder=5,
                     linestyle="--", label="Knee locus: 50%\u219285% taxable SSB")
        ax1.set_xlabel(X_AXIS_LABEL)
        ax1.set_ylabel("Total Tax Owed ($)")
        ax1.set_title(f"{cfg.name}: Total Taxes Owed (SSB = ${ssb:,.0f})")
        ax1.xaxis.set_major_formatter(mticker.FuncFormatter(_dollar_fmt))
        ax1.yaxis.set_major_formatter(mticker.FuncFormatter(_dollar_fmt))
        _add_zone_text_labels(ax1, x_plot, ts_plot, te_plot)

        # === BOTTOM PANEL: Marginal Rate vs Analyst X-Axis ===
        _add_zone_shading(ax2, x_plot, ts_plot, te_plot)
        ax2.step(x_plot, mr_base, where="post", color="black", linewidth=1.5,
                 linestyle="--", label="Baseline Marginal Rate (no SSB)")
        ax2.plot(x_plot, mr_total, color="#e53935", linewidth=2,
                 label="Marginal Rate (with SSB)")
        ax2.scatter(my_x_plot, my_marginal, marker="*", s=500, color="red",
                    edgecolors="white", zorder=3, label="Your Position")
        ax2.set_xlabel(X_AXIS_LABEL)
        ax2.set_ylabel("Marginal Tax Rate (%)")
        ax2.xaxis.set_major_formatter(mticker.FuncFormatter(_dollar_fmt))
        ax2.yaxis.set_major_formatter(mticker.FuncFormatter(_pct_fmt))
        mr_peak = float(np.max(mr_total))
        ax2.set_ylim(0, mr_peak * 1.05 if mr_peak > 0 else 50)
        _add_zone_text_labels(ax2, x_plot, ts_plot, te_plot)

        # Taxable SSB overlay on right axis
        ax2b = ax2.twinx()
        ax2b.plot(x_plot, taxable_ssb_arr, linestyle="--", alpha=0.25, color="gray",
                  label="Taxable SSB ($)")
        ax2b.set_ylabel("Taxable SSB ($)", fontsize=12, alpha=0.5)
        ax2b.yaxis.set_major_formatter(mticker.FuncFormatter(_dollar_fmt))
        # The twin axis is not passed to _add_legend_below, so fold its
        # handles in explicitly.
        twin_handles, twin_labels = ax2b.get_legend_handles_labels()
        ax2b_legend = ax2b.get_legend()
        if ax2b_legend:
            ax2b_legend.remove()

        # Single legend below charts
        _add_legend_below(fig, [ax1, ax2], extra_handles=twin_handles, extra_labels=twin_labels)
        fig.tight_layout()
        # tight_layout resets the margins; re-reserve room for the legend.
        fig.subplots_adjust(bottom=0.16)
        path = _save_fig(fig, "torpedo")
    return {"image_path": path, "key_numbers": key_numbers}
# ---------------------------------------------------------------------------
# Plot 1B: Scenario Comparison ON the Torpedo Curve
# Now supports MULTIPLE new positions (shown in different colors).
# ---------------------------------------------------------------------------
def generate_scenario_torpedo_plot(
    filing_status: str,
    ssb: float,
    old_other_income: float,
    new_other_incomes: list | float,
    scenario_labels: list | None = None,
    x_max: Optional[float] = None,
    n: int = 800,
    delta: float = 100.0,
) -> Dict:
    """
    Scenario comparison overlaid on the torpedo curve.

    Shows the OLD position (red star) and one or more NEW positions (colored
    squares) with arrows connecting them and delta annotations, on the same
    two panels (total tax / marginal rate) as ``generate_torpedo_plot``.

    Args:
        filing_status: Key into ``CONFIGS``.
        ssb: Annual Social Security benefit ($).
        old_other_income: Current other income ($).
        new_other_incomes: A single number (Python or NumPy scalar) or a
            non-empty sequence of numbers.
        scenario_labels: Optional labels, one per new position.  When omitted,
            or when fewer labels than positions are given, the missing ones
            get defaults: "New Scenario" for a single position, otherwise
            "Scenario A: OI=$..." and so on.  Extra labels are ignored.
        x_max, n, delta: As in ``generate_torpedo_plot``; the default x_max
            covers the largest of the old and new incomes.

    Returns:
        Dict with 'image_path', 'old_key_numbers', 'all_new_key_numbers'
        (a list, one per new position, each with 'scenario_label' and
        'scenario_color' added) and 'new_key_numbers' (the first new
        position's dict, for single-panel display).

    Raises:
        ValueError: if *new_other_incomes* is an empty sequence.
    """
    # Normalise inputs (NumPy integer scalars are not ``int`` subclasses).
    if isinstance(new_other_incomes, (int, float, np.integer, np.floating)):
        new_other_incomes = [float(new_other_incomes)]
    else:
        new_other_incomes = [float(v) for v in new_other_incomes]
    if not new_other_incomes:
        # Fail before any figure or temp file is created.
        raise ValueError("new_other_incomes must contain at least one income value")
    n_new = len(new_other_incomes)

    if n_new == 1:
        default_labels = ["New Scenario"]
    else:
        default_labels = [f"Scenario {chr(65+i)}: OI=${v:,.0f}"
                          for i, v in enumerate(new_other_incomes)]
    if scenario_labels is None:
        scenario_labels = default_labels
    else:
        given = list(scenario_labels)
        # Pad a short label list so every scenario gets a label.
        scenario_labels = given + default_labels[len(given):]

    cfg = CONFIGS[filing_status]
    if x_max is None:
        all_oi = [old_other_income] + list(new_other_incomes)
        x_max = max(max(all_oi) * 1.5, 100000)

    def _to_plot(oi):
        # Map other income (scalar or array) onto the analyst x-axis.
        return oi - cfg.standard_deduction + 0.85 * ssb

    with plt.rc_context(PLOT_STYLE):
        # --- Analyst x-axis convention: start the OI grid where x_plot == 0 ---
        x_start = cfg.standard_deduction - 0.85 * ssb
        x = np.linspace(x_start, x_max, n)
        x_plot = _to_plot(x)
        x_clipped = np.maximum(0.0, x)  # OI can't be negative

        # Curves
        tax_total = np.array([tax_with_ssb(xi, ssb, cfg) for xi in x_clipped], dtype=float)
        mr_total = np.array(
            [100.0 * total_marginal_rate(xi, ssb, cfg, delta=delta) for xi in x_clipped],
            dtype=float,
        )
        tax_base = np.array([bracket_tax(max(0.0, oa), cfg) for oa in x_plot], dtype=float)
        mr_base = np.array(
            [100.0 * bracket_marginal_rate(oa + cfg.standard_deduction, cfg) for oa in x_plot],
            dtype=float,
        )
        taxable_ssb_arr = np.array([ssb_tax(xi, ssb, cfg) for xi in x_clipped], dtype=float)

        # Zone boundaries
        torpedo_start, torpedo_end = find_torpedo_bounds(cfg, ssb, x_max)
        ts_plot = _to_plot(torpedo_start) if torpedo_start is not None else None
        te_plot = _to_plot(torpedo_end) if torpedo_end is not None else None

        # Knee sensitivity lines (green)
        k1_x, k1_yt, _k1_ym, k2_x, k2_yt, _k2_ym = _knee_sensitivity_lines(ssb, cfg, x_max)

        # Key numbers for old position
        old_kn = _compute_key_numbers(
            old_other_income, ssb, cfg, ts_plot, te_plot,
            torpedo_start, torpedo_end, delta,
        )

        # Key numbers and plot coordinates for each new position (computed once)
        new_kns = []
        points = []
        for i, (noi, lbl) in enumerate(zip(new_other_incomes, scenario_labels)):
            color = _SCENARIO_COLORS[i % len(_SCENARIO_COLORS)]
            kn = _compute_key_numbers(
                noi, ssb, cfg, ts_plot, te_plot,
                torpedo_start, torpedo_end, delta,
            )
            kn["scenario_label"] = lbl
            kn["scenario_color"] = color
            new_kns.append(kn)
            points.append({
                "color": color,
                "label": lbl,
                "x": _to_plot(noi),
                "tax": tax_with_ssb(noi, ssb, cfg),
                "mr": 100.0 * total_marginal_rate(noi, ssb, cfg, delta=delta),
                # Centre the stacked delta labels around the arrow midpoint.
                "slot": i - (n_new - 1) / 2,
            })

        # Analyst-axis position of the old scenario
        old_x_plot = _to_plot(old_other_income)
        old_tax = tax_with_ssb(old_other_income, ssb, cfg)
        old_mr = 100.0 * total_marginal_rate(old_other_income, ssb, cfg, delta=delta)

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10))

        # === TOP PANEL ===
        _add_zone_shading(ax1, x_plot, ts_plot, te_plot)
        ax1.plot(x_plot, tax_base, color="black", linewidth=2, linestyle="--",
                 label="Baseline Tax (no SSB)")
        ax1.plot(x_plot, tax_total, color="#e53935", linewidth=2,
                 label="Total Tax (with SSB)")
        # Old position (red star)
        ax1.scatter(old_x_plot, old_tax, marker="*", s=500, color="red",
                    edgecolors="white", zorder=4, label="Current Position")
        # Zone boundary markers, evaluated exactly at each boundary
        if ts_plot is not None:
            tax_at_zp = tax_with_ssb(max(0.0, torpedo_start), ssb, cfg)
            ax1.scatter(ts_plot, tax_at_zp, marker="o", color="green",
                        s=120, zorder=3, label="Zero Point")
        if te_plot is not None:
            tax_at_cp = tax_with_ssb(max(0.0, torpedo_end), ssb, cfg)
            ax1.scatter(te_plot, tax_at_cp, marker="D", color="orange",
                        s=100, zorder=3, label="Confluence Point (85% cap)")
        # New positions (colored squares) with arrows and tax deltas
        tax_label_step = float(np.max(tax_total)) * 0.06
        for pt in points:
            ax1.scatter(pt["x"], pt["tax"], marker="s", s=300, color=pt["color"],
                        edgecolors="white", zorder=4, label=pt["label"])
            ax1.annotate("", xy=(pt["x"], pt["tax"]), xytext=(old_x_plot, old_tax),
                         arrowprops=dict(arrowstyle="-|>", color=pt["color"], lw=2.5))
            delta_tax = pt["tax"] - old_tax
            sign = "+" if delta_tax >= 0 else ""
            ax1.text((old_x_plot + pt["x"]) / 2,
                     (old_tax + pt["tax"]) / 2 + pt["slot"] * tax_label_step,
                     f"{sign}${delta_tax:,.0f} tax",
                     fontsize=13, fontweight="bold",
                     color="#c62828" if delta_tax > 0 else "#2e7d32",
                     ha="center", va="bottom",
                     bbox=dict(facecolor="white", alpha=0.85, edgecolor=pt["color"],
                               boxstyle="round,pad=0.3"))
        # Green knee-locus lines (top panel)
        if len(k1_x) > 1:
            ax1.plot(k1_x, k1_yt, color="green", linewidth=1.8, zorder=5,
                     linestyle="--", label="Knee locus: 0%\u219250% taxable SSB")
        if len(k2_x) > 1:
            ax1.plot(k2_x, k2_yt, color="green", linewidth=1.8, zorder=5,
                     linestyle="--", label="Knee locus: 50%\u219285% taxable SSB")
        ax1.set_xlabel(X_AXIS_LABEL)
        ax1.set_ylabel("Total Tax Owed ($)")
        ax1.set_title(f"{cfg.name}: Scenario Comparison (SSB = ${ssb:,.0f})")
        ax1.xaxis.set_major_formatter(mticker.FuncFormatter(_dollar_fmt))
        ax1.yaxis.set_major_formatter(mticker.FuncFormatter(_dollar_fmt))
        _add_zone_text_labels(ax1, x_plot, ts_plot, te_plot)

        # === BOTTOM PANEL ===
        _add_zone_shading(ax2, x_plot, ts_plot, te_plot)
        ax2.step(x_plot, mr_base, where="post", color="black", linewidth=1.5,
                 linestyle="--", label="Baseline Marginal Rate (no SSB)")
        ax2.plot(x_plot, mr_total, color="#e53935", linewidth=2,
                 label="Marginal Rate (with SSB)")
        # Old position (red star)
        ax2.scatter(old_x_plot, old_mr, marker="*", s=500, color="red",
                    edgecolors="white", zorder=3, label="Current Position")
        # New positions with arrows and rate deltas (labels 3 points apart)
        for pt in points:
            ax2.scatter(pt["x"], pt["mr"], marker="s", s=300, color=pt["color"],
                        edgecolors="white", zorder=3, label=pt["label"])
            ax2.annotate("", xy=(pt["x"], pt["mr"]), xytext=(old_x_plot, old_mr),
                         arrowprops=dict(arrowstyle="-|>", color=pt["color"], lw=2.5))
            delta_mr_val = pt["mr"] - old_mr
            sign_mr = "+" if delta_mr_val >= 0 else ""
            ax2.text((old_x_plot + pt["x"]) / 2,
                     (old_mr + pt["mr"]) / 2 + pt["slot"] * 3,
                     f"{sign_mr}{delta_mr_val:.1f}% rate",
                     fontsize=13, fontweight="bold",
                     color="#c62828" if delta_mr_val > 0 else "#2e7d32",
                     ha="center", va="bottom",
                     bbox=dict(facecolor="white", alpha=0.85, edgecolor=pt["color"],
                               boxstyle="round,pad=0.3"))
        ax2.set_xlabel(X_AXIS_LABEL)
        ax2.set_ylabel("Marginal Tax Rate (%)")
        ax2.xaxis.set_major_formatter(mticker.FuncFormatter(_dollar_fmt))
        ax2.yaxis.set_major_formatter(mticker.FuncFormatter(_pct_fmt))
        mr_peak = float(np.max(mr_total))
        ax2.set_ylim(0, mr_peak * 1.05 if mr_peak > 0 else 50)
        _add_zone_text_labels(ax2, x_plot, ts_plot, te_plot)

        # Taxable SSB overlay
        ax2b = ax2.twinx()
        ax2b.plot(x_plot, taxable_ssb_arr, linestyle="--", alpha=0.25, color="gray",
                  label="Taxable SSB ($)")
        ax2b.set_ylabel("Taxable SSB ($)", fontsize=12, alpha=0.5)
        ax2b.yaxis.set_major_formatter(mticker.FuncFormatter(_dollar_fmt))
        twin_handles, twin_labels = ax2b.get_legend_handles_labels()
        ax2b_legend = ax2b.get_legend()
        if ax2b_legend:
            ax2b_legend.remove()

        # Single legend below charts
        _add_legend_below(fig, [ax1, ax2], extra_handles=twin_handles, extra_labels=twin_labels)
        fig.tight_layout()
        # tight_layout resets the margins; re-reserve room for the legend.
        fig.subplots_adjust(bottom=0.16)
        path = _save_fig(fig, "scenario_torpedo")

    return {
        "image_path": path,
        "old_key_numbers": old_kn,
        "all_new_key_numbers": new_kns,
        # Always the first scenario's dict (for panel display), even when
        # there are several; the full list is in 'all_new_key_numbers'.
        "new_key_numbers": new_kns[0],
    }
# ---------------------------------------------------------------------------
# Plot 2: Scenario Comparison (grouped bar chart)
# ---------------------------------------------------------------------------
def generate_scenario_comparison(
    filing_status: str,
    ssb: float,
    scenarios: List[Dict],
) -> Dict:
    """
    Compare 2-4 income scenarios side by side as stacked bars.

    Each bar stacks take-home income, regular (baseline) tax and SSB-driven
    tax, with a coloured zone badge above it.

    Args:
        filing_status: Key into ``CONFIGS``.
        ssb: Annual Social Security benefit ($), shared by all scenarios.
        scenarios: list of dicts with 'label' and 'other_income' keys.

    Returns:
        Dict with 'scenario_results' (one dict per scenario, input order)
        and 'image_path'.
    """
    cfg = CONFIGS[filing_status]
    # Zone bounds depend only on the filing status and SSB, not on the
    # scenario, so compute them once.
    zp, cp = find_torpedo_bounds(cfg, ssb)

    results = []
    for sc in scenarios:
        oi = sc["other_income"]
        detail = tax_with_ssb_detail(oi, ssb, cfg)
        baseline = compute_baseline_tax(oi, cfg)
        ssb_driven = detail["tax"] - baseline
        mr = total_marginal_rate(oi, ssb, cfg)
        zone = classify_zone(oi, ssb, cfg, zp, cp)
        results.append({
            "label": sc["label"],
            "other_income": oi,
            "gross_income": oi + ssb,
            "tax_owed": round(detail["tax"], 2),
            "regular_tax": round(max(0, baseline), 2),
            "ssb_driven_tax": round(max(0, ssb_driven), 2),
            "take_home": round(oi + ssb - detail["tax"], 2),
            "marginal_rate": round(mr * 100, 2),
            "effective_rate": round(detail["effective_rate"], 2),
            "zone": zone,
        })

    with plt.rc_context(PLOT_STYLE):
        labels = [r["label"] for r in results]
        take_homes = [r["take_home"] for r in results]
        reg_taxes = [r["regular_tax"] for r in results]
        ssb_taxes = [r["ssb_driven_tax"] for r in results]
        x_pos = np.arange(len(labels))
        width = 0.55
        fig, ax = plt.subplots(figsize=(max(10, len(labels) * 3), 7))

        ax.bar(x_pos, take_homes, width, label="Take-Home Income",
               color="#4CAF50", edgecolor="white")
        ax.bar(x_pos, reg_taxes, width, bottom=take_homes,
               label="Regular Taxes", color="#c3e3f7", edgecolor="white")
        bottoms = [t + r for t, r in zip(take_homes, reg_taxes)]
        ax.bar(x_pos, ssb_taxes, width, bottom=bottoms,
               label="SSB-Driven Taxes", color="#f7dfc3", edgecolor="white")

        # Annotate bars
        for i, r in enumerate(results):
            # Take-home amount
            ax.text(i, r["take_home"] / 2, f"${r['take_home']:,.0f}",
                    ha="center", va="center", fontsize=13, fontweight="bold", color="white")
            # Total tax
            total_tax = r["regular_tax"] + r["ssb_driven_tax"]
            if total_tax > 0:
                ax.text(i, r["take_home"] + total_tax / 2, f"Tax: ${total_tax:,.0f}",
                        ha="center", va="center", fontsize=11, color="#333")
            # Zone badge at top
            zone_color = ZONE_COLORS.get(r["zone"], ("#eee", "#333"))
            ax.text(i, r["gross_income"] + r["gross_income"] * 0.02,
                    r["zone"], ha="center", va="bottom", fontsize=11,
                    fontweight="bold", color=zone_color[1],
                    bbox=dict(boxstyle="round,pad=0.3", facecolor=zone_color[0], alpha=0.8))

        # Autoscaling ignores text, so make sure the zone badges above the
        # tallest bar stay inside the axes instead of running into the title.
        if results:
            badge_top = max(r["gross_income"] for r in results) * 1.15
            if badge_top > ax.get_ylim()[1]:
                ax.set_ylim(top=badge_top)

        ax.set_ylabel("Dollars ($)")
        ax.set_title(f"Scenario Comparison (SSB = ${ssb:,.0f})")
        ax.set_xticks(x_pos)
        ax.set_xticklabels(labels, fontsize=14)
        ax.yaxis.set_major_formatter(mticker.FuncFormatter(_dollar_fmt))
        _add_legend_below(fig, [ax])
        fig.tight_layout()
        # tight_layout resets the margins; re-reserve room for the legend.
        fig.subplots_adjust(bottom=0.16)
        path = _save_fig(fig, "scenarios")
    return {"scenario_results": results, "image_path": path}
# ---------------------------------------------------------------------------
# Plot 3: Educational Concept Diagrams
# ---------------------------------------------------------------------------
def generate_concept_diagram(
    concept: str,
    user_ssb: Optional[float] = None,
    user_income: Optional[float] = None,
    filing_status: str = "MFJ",
) -> Dict:
    """
    Generate a dynamic educational diagram using the user's actual numbers.

    Args:
        concept: One of "tax_torpedo", "provisional_income",
            "roth_conversion", "rmd", "marginal_vs_effective_rate",
            "tax_zones", "ssb_taxation_rules".
        user_ssb: User's annual SSB; 30000 is used only when not supplied
            (an explicit 0 is kept).
        user_income: User's other income; 40000 is used only when not
            supplied (an explicit 0 is kept).
        filing_status: Key into ``CONFIGS``; unknown keys fall back to "MFJ".

    Returns:
        Dict with 'explanation_text' and 'diagram_path'.  For an unknown
        *concept* the text says so and 'diagram_path' is "".
    """
    cfg = CONFIGS.get(filing_status, CONFIGS["MFJ"])
    # Substitute illustrative defaults only for missing values: `x or default`
    # would also replace a real 0 and misreport the user's own numbers.
    ssb = 30000 if user_ssb is None else user_ssb
    income = 40000 if user_income is None else user_income

    # Lambdas defer both the name lookup and the (expensive) plotting to the
    # one builder actually selected.
    builders = {
        "tax_torpedo": lambda: _diagram_tax_torpedo(cfg, ssb, income),
        "provisional_income": lambda: _diagram_provisional_income(cfg, ssb, income),
        "roth_conversion": lambda: _diagram_roth_conversion(cfg, ssb, income),
        "rmd": lambda: _diagram_rmd(),
        "marginal_vs_effective_rate": lambda: _diagram_marginal_vs_effective(cfg, ssb, income),
        "tax_zones": lambda: _diagram_tax_zones(cfg, ssb, income),
        "ssb_taxation_rules": lambda: _diagram_ssb_rules(cfg, ssb),
    }
    builder = builders.get(concept)
    if builder is None:
        return {"explanation_text": f"Unknown concept: {concept}", "diagram_path": ""}
    return builder()
def _diagram_tax_torpedo(cfg, ssb, income):
    """Simplified marginal-rate chart with a big 'TAX TORPEDO' callout.

    Plots the bracket-only marginal rate and the marginal rate including SSB
    taxation against other income from 0 to ``max(2 * income, 100000)``,
    shades where the latter exceeds the former by more than one percentage
    point, and points a callout at the highest marginal rate on the chart.

    Returns dict with 'explanation_text' and 'diagram_path'.
    """
    with plt.rc_context(PLOT_STYLE):
        x = np.linspace(0, max(income * 2, 100000), 600)
        mr_base = np.array([100.0 * bracket_marginal_rate(xi, cfg) for xi in x])
        mr_total = np.array([100.0 * total_marginal_rate(xi, ssb, cfg) for xi in x])

        fig, ax = plt.subplots(figsize=(14, 7))
        ax.step(x, mr_base, where="post", color="#90CAF9", linewidth=2,
                label="Normal Tax Rate")
        ax.plot(x, mr_total, color="#e53935", linewidth=3,
                label="Your Actual Tax Rate (with SS)")
        ax.fill_between(x, mr_base, mr_total,
                        where=(mr_total > mr_base + 1),
                        alpha=0.3, color="#e53935")

        # Find torpedo peak for annotation
        peak_idx = int(np.argmax(mr_total))
        peak_x = x[peak_idx]
        peak_y = mr_total[peak_idx]
        x_span = x[-1] - x[0]
        # Put the callout on the side of the peak with more room so it stays
        # inside the axes even when the peak is near the right edge.
        if peak_x - x[0] <= 0.5 * x_span:
            text_x, text_ha = peak_x + x_span * 0.15, "left"
        else:
            text_x, text_ha = peak_x - x_span * 0.15, "right"
        ax.annotate(
            "THE TAX TORPEDO\nYour rate spikes here!",
            xy=(peak_x, peak_y),
            xytext=(text_x, peak_y + 5),
            ha=text_ha,
            fontsize=18, fontweight="bold", color="#c62828",
            arrowprops=dict(arrowstyle="->", color="#c62828", lw=3),
            bbox=dict(boxstyle="round,pad=0.5", facecolor="#ffcdd2", alpha=0.9),
        )
        # Autoscaling ignores annotations; leave headroom for the callout.
        ax.set_ylim(top=max(ax.get_ylim()[1], peak_y + 15))

        ax.set_xlabel("Other Income ($)")
        ax.set_ylabel("Marginal Tax Rate (%)")
        ax.set_title("The Tax Torpedo: Hidden Tax Rate Spike on Social Security")
        ax.xaxis.set_major_formatter(mticker.FuncFormatter(_dollar_fmt))
        ax.yaxis.set_major_formatter(mticker.FuncFormatter(_pct_fmt))
        ax.legend(fontsize=14)
        fig.tight_layout()
        path = _save_fig(fig, "concept_torpedo")
    return {
        "explanation_text": (
            "The 'Tax Torpedo' is a hidden tax rate spike that hits people "
            "receiving Social Security. As your other income rises, more of "
            "your Social Security becomes taxable -- on top of the normal tax "
            "on that income. This can push your real tax rate much higher than "
            "the bracket you're officially in."
        ),
        "diagram_path": path,
    }
def _diagram_provisional_income(cfg, ssb, income):
    """Flow diagram showing how provisional income (PI) is calculated.

    Two input boxes (other income, 50% of SSB) feed a PI box.  A result box
    below it states which SSB taxation tier the PI falls in, using the
    filing status's thresholds t1 and t2 (<= t1: 0%, <= t2: up to 50%,
    otherwise up to 85% of SSB taxable).

    Returns dict with 'explanation_text' and 'diagram_path'.
    """
    pi = income + 0.5 * ssb
    t1, t2 = cfg.ssb_thresholds.t1, cfg.ssb_thresholds.t2

    def _literal(text):
        # Matplotlib renders text containing an even number of unescaped "$"
        # as mathtext (e.g. "PI ($X) is below $Y" lost its dollar signs and
        # spaces); escaping them makes every amount render literally.
        return text.replace("$", r"\$")

    box_w, box_h = 2.4, 1.2
    with plt.rc_context(PLOT_STYLE):
        fig, ax = plt.subplots(figsize=(12, 6))
        # Layout is centred on x = 4 so every box lies fully inside the axes
        # (patches are clipped to the axes area).
        ax.set_xlim(0, 8)
        ax.set_ylim(0, 6)
        ax.axis("off")

        # Input boxes on top, provisional-income box centred below them
        boxes = [
            (2, 4.5, f"Your Other Income\n${income:,.0f}", "#90CAF9"),
            (6, 4.5, f"50% of Social Security\n${0.5 * ssb:,.0f}", "#81C784"),
            (4, 2.5, f"Provisional Income\n${pi:,.0f}", "#FFE082"),
        ]
        for bx, by, text, color in boxes:
            ax.add_patch(plt.Rectangle((bx - box_w / 2, by - box_h / 2), box_w, box_h,
                                       facecolor=color, edgecolor="#333", linewidth=2,
                                       zorder=2))
            ax.text(bx, by, _literal(text), ha="center", va="center", fontsize=14,
                    fontweight="bold", zorder=3)

        # Arrows from the bottom of each input box to the top of the PI box
        ax.annotate("", xy=(3.5, 3.1), xytext=(2.5, 3.9),
                    arrowprops=dict(arrowstyle="->", lw=2.5, color="#333"))
        ax.annotate("", xy=(4.5, 3.1), xytext=(5.5, 3.9),
                    arrowprops=dict(arrowstyle="->", lw=2.5, color="#333"))
        ax.text(4, 3.5, "+", fontsize=24, fontweight="bold", ha="center", va="center")

        # Threshold info
        if pi <= t1:
            result_text = f"PI (${pi:,.0f}) is below ${t1:,.0f}\n0% of SS is taxable"
            result_color = "#c8e6c9"
        elif pi <= t2:
            result_text = f"PI (${pi:,.0f}) is between ${t1:,.0f} and ${t2:,.0f}\nUp to 50% of SS is taxable"
            result_color = "#fff9c4"
        else:
            result_text = f"PI (${pi:,.0f}) is above ${t2:,.0f}\nUp to 85% of SS is taxable"
            result_color = "#ffcdd2"
        ax.add_patch(plt.Rectangle((1, 0.2), 6, 1.0,
                                   facecolor=result_color, edgecolor="#333", linewidth=2,
                                   zorder=2))
        ax.text(4, 0.7, _literal(result_text), ha="center", va="center", fontsize=14,
                fontweight="bold", zorder=3)

        ax.set_title("How Provisional Income Determines Your SSB Taxation", fontsize=18, pad=20)
        fig.tight_layout()
        path = _save_fig(fig, "concept_pi")
    return {
        "explanation_text": (
            f"Provisional Income = Your Other Income + half of your Social Security. "
            f"Yours is ${income:,.0f} + ${0.5*ssb:,.0f} = ${pi:,.0f}. "
            f"The IRS uses this number to decide how much of your Social Security is taxable."
        ),
        "diagram_path": path,
    }
def _diagram_roth_conversion(cfg, ssb, income):
    """Bar chart of the extra tax caused by several Roth conversion sizes.

    For conversions of $0, 5k, 10k, 20k, 30k and 50k, each bar shows
    ``tax_with_ssb(income + conversion) - tax_with_ssb(income)``.  Bars are
    green below $2,000 extra, orange below $5,000 and red otherwise.

    Returns dict with 'explanation_text' and 'diagram_path'.
    """
    def _bar_color(extra_tax):
        if extra_tax < 2000:
            return "#4CAF50"
        if extra_tax < 5000:
            return "#FFB74D"
        return "#ef5350"

    with plt.rc_context(PLOT_STYLE):
        amounts = [0, 5000, 10000, 20000, 30000, 50000]
        totals = [tax_with_ssb(income + amt, ssb, cfg) for amt in amounts]
        no_conversion_tax = totals[0]
        extra_costs = [total - no_conversion_tax for total in totals]

        fig, ax = plt.subplots(figsize=(12, 6))
        tick_labels = [f"${amt:,.0f}" for amt in amounts]
        bar_colors = [_bar_color(extra) for extra in extra_costs]
        rects = ax.bar(tick_labels, extra_costs, color=bar_colors,
                       edgecolor="white", linewidth=2)
        for rect, extra in zip(rects, extra_costs):
            center_x = rect.get_x() + rect.get_width() / 2
            ax.text(center_x, rect.get_height() + 50, f"${extra:,.0f}",
                    ha="center", fontsize=13, fontweight="bold")

        ax.set_xlabel("Roth Conversion Amount")
        ax.set_ylabel("Additional Tax Cost ($)")
        ax.set_title("Tax Cost of Different Roth Conversion Amounts")
        ax.yaxis.set_major_formatter(mticker.FuncFormatter(_dollar_fmt))
        plt.tight_layout()
        path = _save_fig(fig, "concept_roth")

    explanation = (
        "A Roth conversion moves money from a Traditional IRA (taxed when withdrawn) "
        "to a Roth IRA (tax-free when withdrawn). You pay tax now on the converted amount, "
        "but never again. The key is finding how much you can convert before hitting "
        "the expensive tax torpedo zone."
    )
    return {"explanation_text": explanation, "diagram_path": path}
def _diagram_rmd():
    """Bar chart of the RMD percentage (100 / distribution period) by age.

    Distribution periods for ages 73-95 come from
    ``rmd_tables.UNIFORM_LIFETIME_TABLE``; an age missing from that table
    falls back to a period of 2.0 (i.e. 50%).  The explanation quotes the
    first bar's percentage so the text always agrees with the chart.

    Returns dict with 'explanation_text' and 'diagram_path'.
    """
    from rmd_tables import UNIFORM_LIFETIME_TABLE

    ages = list(range(73, 96))
    # NOTE(review): the 2.0 fallback silently shows 50% for an age the table
    # lacks -- confirm the table covers 73-95.
    periods = [UNIFORM_LIFETIME_TABLE.get(a, 2.0) for a in ages]
    pcts = [100.0 / p for p in periods]

    with plt.rc_context(PLOT_STYLE):
        fig, ax = plt.subplots(figsize=(12, 6))
        ax.bar(ages, pcts, color="#42A5F5", edgecolor="white")
        # Label every third bar to avoid crowding
        for i, (age, pct) in enumerate(zip(ages, pcts)):
            if i % 3 == 0:
                ax.text(age, pct + 0.1, f"{pct:.1f}%", ha="center", fontsize=10)
        ax.set_xlabel("Age")
        ax.set_ylabel("RMD as % of Balance")
        ax.set_title("Required Minimum Distributions: Percentage Increases with Age")
        ax.yaxis.set_major_formatter(mticker.FuncFormatter(_pct_fmt))
        fig.tight_layout()
        path = _save_fig(fig, "concept_rmd")

    # Derive the quoted starting percentage from the plotted table instead of
    # hard-coding it (the previous "3.6%" disagreed with 1/26.5 ~= 3.8%).
    start_age, start_pct = ages[0], pcts[0]
    return {
        "explanation_text": (
            f"Starting at age {start_age}, the IRS requires you to withdraw a minimum amount "
            "from your Traditional IRA, 401(k), and 403(b) each year. This is called "
            "a Required Minimum Distribution (RMD). The percentage you must withdraw "
            f"increases as you age -- starting around {start_pct:.1f}% at {start_age} "
            "and rising each year."
        ),
        "diagram_path": path,
    }
def _diagram_marginal_vs_effective(cfg, ssb, income):
    """Plot marginal vs. effective tax rate as a function of other income.

    * Marginal rate: ``100 * total_marginal_rate(x, ssb, cfg)``, the tax on
      the next dollar of other income.
    * Effective rate: ``tax_with_ssb(x, ssb, cfg)`` divided by TOTAL income
      (other income + Social Security).  This matches the chart's "overall
      average" label and the "average rate on all your income" explanation.
      Dividing by other income alone excluded Social Security, and it spiked
      toward infinity near x = 0 whenever benefits alone produced tax.

    The user's own position (``income``) is marked and annotated on both curves.

    Args:
        cfg: tax configuration passed through to the tax_engine helpers.
        ssb: Social Security benefit amount.
        income: the user's other income.

    Returns:
        dict with ``"explanation_text"`` and ``"diagram_path"`` (whatever
        ``_save_fig`` returns for the figure).
    """

    def _effective_pct(other_income):
        # Average rate over all income; guard against a zero/negative total.
        total_income = other_income + ssb
        if total_income <= 0:
            return 0.0
        return 100.0 * tax_with_ssb(other_income, ssb, cfg) / total_income

    with plt.rc_context(PLOT_STYLE):
        x_max = max(income * 2, 100000)
        x = np.linspace(0, x_max, 500)
        marginals = np.array([100.0 * total_marginal_rate(xi, ssb, cfg) for xi in x])
        effectives = np.array([_effective_pct(xi) for xi in x])

        fig, ax = plt.subplots(figsize=(12, 6))
        ax.plot(x, marginals, color="#e53935", linewidth=2.5, label="Marginal Rate (next dollar)")
        ax.plot(x, effectives, color="#1565c0", linewidth=2.5, label="Effective Rate (overall average)")

        my_marginal = 100.0 * total_marginal_rate(income, ssb, cfg)
        my_effective = _effective_pct(income)
        ax.scatter([income], [my_marginal], s=200, color="#e53935", zorder=5, edgecolors="white")
        ax.scatter([income], [my_effective], s=200, color="#1565c0", zorder=5, edgecolors="white")

        # Offset labels by a fraction of the axis width; ``income * 0.1`` gave
        # no offset at all when income == 0.
        text_x = income + 0.05 * x_max
        ax.annotate(f"Your marginal: {my_marginal:.1f}%",
                    xy=(income, my_marginal), xytext=(text_x, my_marginal + 3),
                    fontsize=13, arrowprops=dict(arrowstyle="->", color="#e53935"),
                    color="#e53935", fontweight="bold")
        # Keep the effective-rate label from dropping below the plotted area.
        ax.annotate(f"Your effective: {my_effective:.1f}%",
                    xy=(income, my_effective), xytext=(text_x, max(my_effective - 5, 0.5)),
                    fontsize=13, arrowprops=dict(arrowstyle="->", color="#1565c0"),
                    color="#1565c0", fontweight="bold")

        ax.set_xlabel("Other Income ($)")
        ax.set_ylabel("Tax Rate (%)")
        ax.set_title("Marginal vs. Effective Tax Rate")
        ax.xaxis.set_major_formatter(mticker.FuncFormatter(_dollar_fmt))
        ax.yaxis.set_major_formatter(mticker.FuncFormatter(_pct_fmt))
        ax.legend(fontsize=14)
        plt.tight_layout()
        path = _save_fig(fig, "concept_rates")

    return {
        "explanation_text": (
            f"Your MARGINAL rate ({my_marginal:.1f}%) is the tax on your next dollar of income. "
            f"Your EFFECTIVE rate ({my_effective:.1f}%) is the average rate on all your income. "
            "The marginal rate matters most for decisions about withdrawals and conversions."
        ),
        "diagram_path": path,
    }
def _diagram_tax_zones(cfg, ssb, income):
    """Marginal-rate curve with the No-Tax / High-Tax / Same-Old zones shaded.

    Zone boundaries come from ``find_torpedo_bounds(cfg, ssb, x_max)``: the
    first value ends the No-Tax zone and the second ends the High-Tax zone.
    A zone is drawn only when the boundaries it needs are not ``None``.  The
    user's own income is marked with a star and a "YOU ARE HERE" label.

    Returns:
        dict with ``"explanation_text"`` and ``"diagram_path"`` (whatever
        ``_save_fig`` returns for the figure).
    """
    x_max = max(income * 2, 100000)
    no_tax_end, torpedo_end = find_torpedo_bounds(cfg, ssb, x_max)

    with plt.rc_context(PLOT_STYLE):
        grid = np.linspace(0, x_max, 600)
        rates = np.array([100.0 * total_marginal_rate(g, ssb, cfg) for g in grid])
        fig, ax = plt.subplots(figsize=(14, 7))
        label_y = max(rates) * 0.85  # all zone labels sit at the same height

        # (start, end, shade, alpha, label, label colour) -- kept in the
        # original drawing order: green, red, blue.
        zone_specs = []
        if no_tax_end is not None:
            zone_specs.append((0, no_tax_end, "green", 0.15,
                               "NO-TAX\nZONE", "#2e7d32"))
        if no_tax_end is not None and torpedo_end is not None:
            zone_specs.append((no_tax_end, torpedo_end, "red", 0.12,
                               "HIGH-TAX\nZONE\n(Tax Torpedo!)", "#c62828"))
        if torpedo_end is not None:
            zone_specs.append((torpedo_end, x_max, "blue", 0.08,
                               "SAME-OLD\nZONE", "#1565c0"))

        for start, end, shade, alpha, label, label_color in zone_specs:
            ax.axvspan(start, end, color=shade, alpha=alpha)
            ax.text((start + end) / 2, label_y, label,
                    ha="center", fontsize=20, fontweight="bold", color=label_color,
                    bbox=dict(boxstyle="round", facecolor="white", alpha=0.8))

        ax.plot(grid, rates, color="#333", linewidth=2.5)

        # Mark where the user currently sits on the curve.
        user_rate = 100.0 * total_marginal_rate(income, ssb, cfg)
        ax.scatter([income], [user_rate], s=400, color="red", marker="*",
                   edgecolors="white", zorder=5)
        ax.annotate("YOU ARE HERE", xy=(income, user_rate),
                    xytext=(income, user_rate + 5), fontsize=16, fontweight="bold",
                    color="red", ha="center")

        ax.set_xlabel("Other Income ($)")
        ax.set_ylabel("Marginal Tax Rate (%)")
        ax.set_title("The Three Tax Zones")
        ax.xaxis.set_major_formatter(mticker.FuncFormatter(_dollar_fmt))
        ax.yaxis.set_major_formatter(mticker.FuncFormatter(_pct_fmt))
        plt.tight_layout()
        saved_path = _save_fig(fig, "concept_zones")

    explanation = (
        "There are three tax zones: "
        "The GREEN No-Tax Zone (your income is low enough that you owe no federal tax). "
        "The RED High-Tax Zone (the 'torpedo' -- your Social Security is being taxed "
        "at accelerated rates). "
        "The BLUE Same-Old Zone (past the torpedo -- normal tax rates apply)."
    )
    return {"explanation_text": explanation, "diagram_path": saved_path}
def _diagram_ssb_rules(cfg, ssb):
    """Plot the percentage of Social Security that is taxable vs. other income.

    The curve is ``ssb_tax(x, ssb, cfg) / ssb * 100``.  This code treats the
    ``ssb_tax`` result as the taxable SSB amount.  Dashed vertical lines mark
    the other-income levels at which provisional income (other income + half
    of SSB) reaches ``cfg.ssb_thresholds.t1`` and ``t2``.

    Robustness fixes over the previous version:
    * ``ssb == 0`` no longer produces 0/0 NaNs (the curve is flat at 0%).
    * When half of SSB alone already exceeds a threshold, that boundary lies
      at negative other income.  Its line is no longer drawn (it used to
      stretch the x-axis below $0).  The label of any tier that does not
      exist in the plotted x >= 0 range is also omitted.
    * Removed an unused provisional-income array.

    Args:
        cfg: tax configuration; must expose ``ssb_thresholds.t1/t2`` and ``name``.
        ssb: Social Security benefit amount.

    Returns:
        dict with ``"explanation_text"`` and ``"diagram_path"`` (whatever
        ``_save_fig`` returns for the figure).
    """
    t1 = cfg.ssb_thresholds.t1
    t2 = cfg.ssb_thresholds.t2
    # Other-income levels where provisional income crosses t1 and t2; negative
    # when half of SSB by itself is already past the threshold.
    tier2_start = t1 - 0.5 * ssb
    tier3_start = t2 - 0.5 * ssb

    with plt.rc_context(PLOT_STYLE):
        x = np.linspace(0, t2 * 2.5, 500)
        taxable = np.array([ssb_tax(xi, ssb, cfg) for xi in x])
        if ssb > 0:
            pct_taxable = taxable / ssb * 100
        else:
            # No benefits -> nothing is taxable; avoid dividing by zero.
            pct_taxable = np.zeros_like(x)

        fig, ax = plt.subplots(figsize=(12, 6))
        ax.plot(x, pct_taxable, color="#e53935", linewidth=3)

        # Tier boundaries, drawn only when they fall inside the x >= 0 range.
        if tier2_start >= 0:
            ax.axvline(tier2_start, linestyle="--", color="green", linewidth=2)
        if tier3_start >= 0:
            ax.axvline(tier3_start, linestyle="--", color="orange", linewidth=2)
        ax.axhline(50, linestyle=":", color="gray", alpha=0.5)
        ax.axhline(85, linestyle=":", color="gray", alpha=0.5)

        # Tier labels; a tier whose range is empty for x >= 0 gets no label.
        if tier2_start > 0:
            ax.text(0, 2, "Tier 1: 0% Taxable", fontsize=14, color="#2e7d32", fontweight="bold")
        if tier3_start > 0:
            ax.text(max(0, tier2_start) + 1000, 30, "Tier 2: Up to 50%", fontsize=14,
                    color="#F57F17", fontweight="bold")
        ax.text(max(0, tier3_start) + 1000, 70, "Tier 3: Up to 85%", fontsize=14,
                color="#c62828", fontweight="bold")

        ax.set_xlabel("Other Income ($)")
        ax.set_ylabel("% of Social Security That Is Taxable")
        ax.set_title(f"Social Security Taxation Rules ({cfg.name})")
        ax.xaxis.set_major_formatter(mticker.FuncFormatter(_dollar_fmt))
        ax.yaxis.set_major_formatter(mticker.FuncFormatter(_pct_fmt))
        ax.set_ylim(-5, 100)
        plt.tight_layout()
        path = _save_fig(fig, "concept_ssb_rules")

    return {
        "explanation_text": (
            "The IRS taxes your Social Security in three tiers based on your "
            "'Provisional Income' (other income + half of SS): "
            f"Below ${t1:,.0f}: 0% taxable. "
            f"${t1:,.0f} to ${t2:,.0f}: up to 50% taxable. "
            f"Above ${t2:,.0f}: up to 85% taxable. "
            "The maximum is 85% -- the IRS never taxes more than 85% of your SS."
        ),
        "diagram_path": path,
    }