| | import plotly.graph_objects as go |
| | import plotly.io as pio |
| | import numpy as np |
| |
|
| | """ |
| | Stacked bar chart: GPU memory breakdown vs sequence length, with menus for Model Size and Recomputation. |
| | Responsive, no zoom/pan, clean hover; styled to match the minimal theme. |
| | """ |
| |
|
| | |
# Sequence lengths on the x-axis, expressed as multiples of the shortest
# length (1024 tokens).
_SEQ_MULTIPLIERS = (1, 2, 4, 8)

# Tick labels are kept as strings; the numeric multipliers drive the
# activation-memory scaling.
seq_labels = [str(1024 * multiplier) for multiplier in _SEQ_MULTIPLIERS]
seq_scale = np.array(_SEQ_MULTIPLIERS, dtype=float)
| |
|
| | |
# Bar colour for each memory component. The insertion order is the order in
# which the component traces are added for every model size.
_COMPONENT_COLORS = {
    "parameters": "rgb(78, 165, 183)",
    "gradients": "rgb(227, 138, 66)",
    "optimizer": "rgb(232, 137, 171)",
    "activations": "rgb(206, 192, 250)",
}

# (name, colour) pairs consumed by the trace-building and menu code.
components = list(_COMPONENT_COLORS.items())
| |
|
| | |
# (model size label, parameter memory in GB), in the order the sizes appear
# in the "Model Size" menu.
_PARAMS_MEM_GB = (
    ("1B", 4.0),
    ("3B", 13.3),
    ("8B", 26.0),
    ("70B", 244.0),
    ("405B", 1520.0),
)

model_sizes = [label for label, _ in _PARAMS_MEM_GB]
params_mem = dict(_PARAMS_MEM_GB)
| | |
| |
|
| | |
# Activation memory (GB) per model size at the shortest sequence length
# (scale 1); activations_curve() multiplies it by the squared length scale.
act_coeff = dict(
    zip(
        ("1B", "3B", "8B", "70B", "405B"),
        (3.6, 9.3, 46.2, 145.7, 1519.9),
    )
)
| |
|
def activations_curve(size_key: str, recompute: str) -> np.ndarray:
    """Return the activation memory for every sequence length.

    The curve is the model's ``act_coeff`` entry multiplied by the square of
    ``seq_scale``, then reduced according to the recomputation mode.

    Args:
        size_key: Key into ``act_coeff`` (e.g. ``"8B"``).
        recompute: ``"selective"`` keeps 1/4 of the activations, ``"full"``
            keeps 1/16; any other value leaves the curve unscaled.

    Returns:
        One value per entry of ``seq_scale``.

    Raises:
        KeyError: If ``size_key`` is not present in ``act_coeff``.
    """
    unscaled = act_coeff[size_key] * (seq_scale ** 2)
    kept_fraction = {"selective": 0.25, "full": 1.0 / 16.0}.get(recompute)
    if kept_fraction is None:
        # No (or unrecognised) recomputation mode: all activations are kept.
        return unscaled
    return unscaled * kept_fraction
| |
|
def stack_for(size_key: str, recompute: str):
    """Build the per-component memory series for one model size and mode.

    Parameters and gradients each take ``params_mem[size_key]``, the
    optimizer takes twice that, and none of the three depends on sequence
    length. Activations come from ``activations_curve``.

    Args:
        size_key: Key into ``params_mem`` / ``act_coeff``.
        recompute: Recomputation mode forwarded to ``activations_curve``.

    Returns:
        Dict mapping component name to an array with one value per entry of
        ``seq_scale``.
    """
    weights_gb = params_mem[size_key]

    def constant_series(value: float) -> np.ndarray:
        # Same value at every sequence length, shaped like seq_scale.
        return np.full_like(seq_scale, value, dtype=float)

    return {
        "parameters": constant_series(weights_gb),
        "gradients": constant_series(weights_gb),
        "optimizer": constant_series(2.0 * weights_gb),
        "activations": activations_curve(size_key, recompute),
    }
| |
|
| | |
# Recomputation modes offered in the "Recomputation" menu.
recomp_modes = ["none", "selective", "full"]


def _stacks_by_size(mode):
    """Return {model size: component stacks} for one recomputation mode."""
    return {size: stack_for(size, mode) for size in model_sizes}


# Precomputed series for every combination: Y[mode][size][component] -> array.
Y = {mode: _stacks_by_size(mode) for mode in recomp_modes}
| |
|
| | |
# One bar trace per (model size, component), added size-major so that the
# traces of model i occupy indices [i * len(components), (i + 1) * len(components)).
# All traces start with the "none" recomputation data; only the first model
# size is visible initially.
_HOVER_TEMPLATE = "Seq len=%{x}<br>Mem=%{y:.1f}GB<br>%{data.name}<extra></extra>"

fig = go.Figure()
for size in model_sizes:
    shown_initially = size == model_sizes[0]
    series = Y["none"][size]
    for comp_name, color in components:
        fig.add_trace(
            go.Bar(
                x=seq_labels,
                y=series[comp_name],
                name=comp_name,
                marker={"color": color},
                hovertemplate=_HOVER_TEMPLATE,
                showlegend=True,
                visible=shown_initially,
            )
        )
| |
|
| | |
def max_total(size: str, mode: str) -> float:
    """Return the height of the tallest stacked bar for one size and mode.

    Args:
        size: Model size key into ``Y[mode]``.
        mode: Recomputation mode key into ``Y``.

    Returns:
        The largest per-sequence-length sum of the four memory components.
    """
    series = Y[mode][size]
    component_names = ("parameters", "gradients", "optimizer", "activations")
    # Element-wise sum across components, accumulated left to right.
    bar_heights = sum(series[name] for name in component_names)
    return float(bar_heights.max())
| |
|
# Y-axis upper bound per (mode, size): the tallest stacked bar plus 5% headroom.
layout_y_ranges = {}
for _mode in recomp_modes:
    layout_y_ranges[_mode] = {
        _size: 1.05 * max_total(_size, _mode) for _size in model_sizes
    }
| |
|
| | |
# Base layout: stacked bars on a transparent background, horizontal legend
# above the plot, and both axes locked (fixedrange) so the chart cannot be
# zoomed or panned.
_TRANSPARENT = "rgba(0,0,0,0)"

fig.update_layout(
    barmode="stack",
    autosize=True,
    paper_bgcolor=_TRANSPARENT,
    plot_bgcolor=_TRANSPARENT,
    margin={"l": 40, "r": 28, "t": 20, "b": 40},
    hovermode="x unified",
    legend={
        "orientation": "h",
        "yanchor": "bottom",
        "y": 1.02,
        "xanchor": "left",
        "x": 0,
    },
    xaxis={"title": {"text": "Sequence Length"}, "fixedrange": True},
    yaxis={"title": {"text": "Memory (GB)"}, "fixedrange": True},
)
| |
|
| | |
# "Model Size" menu: each button shows only the traces of one model size.
#
# Traces were added size-major, so the traces of model i are the consecutive
# run of len(components) entries starting at i * len(components).
#
# The y-axis is re-autoranged instead of being pinned to a precomputed range.
# The two menus are independent, so a size button cannot know which
# recomputation mode is currently selected; a range computed for one fixed
# mode would be wrong whenever another mode is active. The dotted key
# addresses only the autorange attribute of the y-axis.
buttons_sizes = []
for i, size in enumerate(model_sizes):
    visible = [j == i for j in range(len(model_sizes)) for _ in components]
    buttons_sizes.append(
        dict(
            label=size,
            method="update",
            args=[
                {"visible": visible},
                {"yaxis.autorange": True},
            ],
        )
    )
| |
|
| | |
def y_for_mode(mode: str):
    """Return the y-series of every trace for one recomputation mode.

    The list is ordered model size first, then component, so it has one
    entry per (size, component) pair in ``model_sizes`` x ``components``.

    Args:
        mode: Recomputation mode key into ``Y``.

    Returns:
        List of arrays, one per (size, component) pair.
    """
    return [
        Y[mode][size][comp_name]
        for size in model_sizes
        for comp_name, _ in components
    ]
| |
|
# "Recomputation" menu: each button swaps the y-data of every trace (hidden
# ones included) to the selected mode, so a later model-size switch already
# shows the right data.
#
# The y-axis is re-autoranged instead of being pinned to the maximum over all
# model sizes. A recomputation button cannot know which model size is
# currently visible, and a range sized for the largest model squashes the
# bars of the small ones to nothing.
buttons_recomp = [
    dict(
        label=label,
        method="update",
        args=[
            {"y": y_for_mode(mode)},
            {"yaxis.autorange": True},
        ],
    )
    for mode, label in [("none", "None"), ("selective", "selective"), ("full", "full")]
]
| |
|
# Two dropdowns placed to the right of the plot, each with a caption
# annotation sitting directly above it (same x/y, opposite y-anchor).
_MENU_X = 1.03
_MENU_ROWS = (
    # (caption, y position in paper coordinates, buttons)
    ("Model Size:", 0.60, buttons_sizes),
    ("Recomputation:", 0.40, buttons_recomp),
)

fig.update_layout(
    updatemenus=[
        dict(
            type="dropdown",
            x=_MENU_X,
            xanchor="left",
            y=row_y,
            yanchor="top",
            showactive=True,
            active=0,
            buttons=row_buttons,
        )
        for _, row_y, row_buttons in _MENU_ROWS
    ],
    annotations=[
        dict(
            text=caption,
            x=_MENU_X,
            xanchor="left",
            xref="paper",
            y=row_y,
            yanchor="bottom",
            yref="paper",
            showarrow=False,
        )
        for caption, row_y, _ in _MENU_ROWS
    ],
)
| |
|
| | |
# Export an embeddable HTML fragment: no <html> wrapper and no bundled
# plotly.js, so the host page must load plotly.js itself. The mode bar and
# scroll zoom are disabled and the chart resizes with its container.
_EXPORT_CONFIG = {
    'displayModeBar': False,
    'responsive': True,
    'scrollZoom': False,
}

pio.write_html(
    fig,
    file="./plotly-bar.html",
    include_plotlyjs=False,
    full_html=False,
    config=_EXPORT_CONFIG,
)
| |
|
| |
|