import plotly.graph_objects as go import plotly.io as pio import numpy as np """ Stacked bar chart: GPU memory breakdown vs sequence length, with menus for Model Size and Recomputation. Responsive, no zoom/pan, clean hover; styled to match the minimal theme. """ # Axes seq_labels = ["1024", "2048", "4096", "8192"] seq_scale = np.array([1, 2, 4, 8], dtype=float) # Components and colors (aligned with the provided example) components = [ ("parameters", "rgb(78, 165, 183)"), ("gradients", "rgb(227, 138, 66)"), ("optimizer", "rgb(232, 137, 171)"), ("activations", "rgb(206, 192, 250)"), ] # Model sizes and base memory (GB) for params/grad/opt (constant vs seq), by size model_sizes = ["1B", "3B", "8B", "70B", "405B"] params_mem = { "1B": 4.0, "3B": 13.3, "8B": 26.0, "70B": 244.0, "405B": 1520.0, } # Optimizer ~= 2x params; gradients ~= params (illustrative) # Activations base coefficient per size (growth ~ coeff * (seq/1024)^2) act_coeff = { "1B": 3.6, "3B": 9.3, "8B": 46.2, "70B": 145.7, "405B": 1519.9, } def activations_curve(size_key: str, recompute: str) -> np.ndarray: base = act_coeff[size_key] * (seq_scale ** 2) if recompute == "selective": return base * 0.25 if recompute == "full": return base * (1.0/16.0) return base def stack_for(size_key: str, recompute: str): p = np.full_like(seq_scale, params_mem[size_key], dtype=float) g = np.full_like(seq_scale, params_mem[size_key], dtype=float) o = np.full_like(seq_scale, 2.0 * params_mem[size_key], dtype=float) a = activations_curve(size_key, recompute) return { "parameters": p, "gradients": g, "optimizer": o, "activations": a, } # Precompute all combinations recomp_modes = ["none", "selective", "full"] Y = {mode: {size: stack_for(size, mode) for size in model_sizes} for mode in recomp_modes} # Build traces: 4 traces per size (20 total). Start with size index 0 visible fig = go.Figure() for size in model_sizes: for comp_name, color in components: fig.add_bar( x=seq_labels, y=Y["none"][size][comp_name], name=comp_name, marker=dict(color=color), hovertemplate="Seq len=%{x}
Mem=%{y:.1f}GB
%{data.name}", showlegend=True, visible=(size == model_sizes[0]), ) # Compute y-axis ranges per size and recomputation def max_total(size: str, mode: str) -> float: stacks = Y[mode][size] totals = stacks["parameters"] + stacks["gradients"] + stacks["optimizer"] + stacks["activations"] return float(np.max(totals)) layout_y_ranges = {mode: {size: 1.05 * max_total(size, mode) for size in model_sizes} for mode in recomp_modes} # Layout fig.update_layout( barmode="stack", autosize=True, paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(0,0,0,0)", margin=dict(l=40, r=28, t=20, b=40), hovermode="x unified", legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0), xaxis=dict(title=dict(text="Sequence Length"), fixedrange=True), yaxis=dict(title=dict(text="Memory (GB)"), fixedrange=True), ) # Updatemenus: Model Size (toggle visibility) buttons_sizes = [] for i, size in enumerate(model_sizes): visible = [False] * (len(model_sizes) * len(components)) start = i * len(components) for j in range(len(components)): visible[start + j] = True buttons_sizes.append(dict( label=size, method="update", args=[ {"visible": visible}, {"yaxis": {"range": [0, layout_y_ranges["none"][size]]}}, ], )) # Updatemenus: Recomputation (restyle y across all traces) def y_for_mode(mode: str): ys = [] for size in model_sizes: stacks = Y[mode][size] for comp_name, _ in components: ys.append(stacks[comp_name]) return ys buttons_recomp = [] for mode, label in [("none", "None"), ("selective", "selective"), ("full", "full")]: ys = y_for_mode(mode) # Flatten into the format expected by Plotly for multiple traces buttons_recomp.append(dict( label=label, method="update", args=[ {"y": ys}, {"yaxis": {"range": [0, max(layout_y_ranges[mode].values())]}}, ], )) fig.update_layout( updatemenus=[ dict( type="dropdown", x=1.03, xanchor="left", y=0.60, yanchor="top", showactive=True, active=0, buttons=buttons_sizes, ), dict( type="dropdown", x=1.03, xanchor="left", y=0.40, yanchor="top", showactive=True, active=0, buttons=buttons_recomp, ), ], annotations=[ dict(text="Model Size:", x=1.03, xanchor="left", xref="paper", y=0.60, yanchor="bottom", yref="paper", showarrow=False), dict(text="Recomputation:", x=1.03, xanchor="left", xref="paper", y=0.40, yanchor="bottom", yref="paper", showarrow=False), ], ) # Write fragment fig.write_html("./plotly-bar.html", include_plotlyjs=False, full_html=False, config={ 'displayModeBar': False, 'responsive': True, 'scrollZoom': False, })