|
|
import plotly.graph_objects as go |
|
|
import plotly.io as pio |
|
|
import numpy as np |
|
|
|
|
|
""" |
|
|
Stacked bar chart: GPU memory breakdown vs sequence length, with menus for Model Size and Recomputation. |
|
|
Responsive, no zoom/pan, clean hover; styled to match the minimal theme. |
|
|
""" |
|
|
|
|
|
|
|
|
# Sequence lengths on the x-axis. The labels are the category strings, and
# seq_scale is each length relative to the shortest one (1024 -> 1.0).
_seq_lengths = [1024, 2048, 4096, 8192]
seq_labels = [str(n) for n in _seq_lengths]
seq_scale = np.array(_seq_lengths, dtype=float) / _seq_lengths[0]
|
|
|
|
|
|
|
|
# (trace name, fill colour) pairs for the memory components, in the order the
# traces are added to each stacked bar.
_component_names = ("parameters", "gradients", "optimizer", "activations")
_component_colors = (
    "rgb(78, 165, 183)",
    "rgb(227, 138, 66)",
    "rgb(232, 137, 171)",
    "rgb(206, 192, 250)",
)
components = list(zip(_component_names, _component_colors))
|
|
|
|
|
|
|
|
model_sizes = ["1B", "3B", "8B", "70B", "405B"]

# Per-size constants, listed positionally in the same order as model_sizes.
# params_mem: parameter memory in GB (the chart's y unit). gradients reuse
#   this value and optimizer state doubles it, in stack_for.
# act_coeff: activation memory in GB at seq_scale == 1 (1024 tokens), before
#   any recomputation reduction; activations_curve scales it by seq_scale**2.
params_mem = dict(zip(model_sizes, [4.0, 13.3, 26.0, 244.0, 1520.0]))
act_coeff = dict(zip(model_sizes, [3.6, 9.3, 46.2, 145.7, 1519.9]))
|
|
|
|
|
# Fraction of the full activation memory kept under each recomputation mode.
recompute_factors = {
    "none": 1.0,
    "selective": 0.25,
    "full": 1.0 / 16.0,
}


def activations_curve(size_key: str, recompute: str) -> np.ndarray:
    """Return activation memory (GB) at each sequence length.

    Activation memory is modelled as ``act_coeff[size_key] * seq_scale**2``.
    This is quadratic in sequence length relative to the 1024-token baseline.
    The result is then scaled by the reduction factor for *recompute*.

    Args:
        size_key: A key of ``act_coeff`` (e.g. "1B").
        recompute: One of ``recompute_factors``: "none", "selective" or "full".

    Returns:
        A float array with one entry per element of ``seq_scale``.

    Raises:
        ValueError: If *recompute* is not a known mode. Previously, unknown
            modes silently produced the "none" curve.
        KeyError: If *size_key* is not in ``act_coeff``.
    """
    # Validate the mode first so a typo fails loudly instead of plotting
    # un-reduced activations under a misleading label.
    try:
        factor = recompute_factors[recompute]
    except KeyError:
        raise ValueError(
            f"unknown recompute mode {recompute!r}; "
            f"expected one of {sorted(recompute_factors)}"
        ) from None
    base = act_coeff[size_key] * (seq_scale ** 2)
    return base * factor
|
|
|
|
|
def stack_for(size_key: str, recompute: str):
    """Build the per-component memory arrays (GB) for one model size and mode.

    Parameters, gradients and optimizer state are constant across sequence
    length. Gradients equal the parameter memory, and optimizer state is twice
    it. Activations come from ``activations_curve``.

    Returns:
        A dict mapping "parameters", "gradients", "optimizer" and
        "activations" to float arrays shaped like ``seq_scale``.
    """
    param_gb = params_mem[size_key]

    def flat(value):
        # A constant row, one entry per sequence length.
        return np.full_like(seq_scale, value, dtype=float)

    return {
        "parameters": flat(param_gb),
        "gradients": flat(param_gb),
        "optimizer": flat(2.0 * param_gb),
        "activations": activations_curve(size_key, recompute),
    }
|
|
|
|
|
|
|
|
recomp_modes = ["none", "selective", "full"]


def _stacks_for_mode(mode):
    # Map each model size to its component stack for a single recompute mode.
    return {size: stack_for(size, mode) for size in model_sizes}


# Lookup table: Y[mode][size][component] -> array over sequence lengths.
Y = {mode: _stacks_for_mode(mode) for mode in recomp_modes}
|
|
|
|
|
|
|
|
fig = go.Figure()

# One bar trace per (model size, component), appended size-major. This makes
# trace index = size position * len(components) + component position.
# Every trace starts with the "none" recomputation data, and only the first
# model size is initially visible.
for size in model_sizes:
    for comp_name, color in components:
        fig.add_trace(
            go.Bar(
                name=comp_name,
                x=seq_labels,
                y=Y["none"][size][comp_name],
                visible=(size == model_sizes[0]),
                showlegend=True,
                marker=dict(color=color),
                hovertemplate="Seq len=%{x}<br>Mem=%{y:.1f}GB<br>%{data.name}<extra></extra>",
            )
        )
|
|
|
|
|
|
|
|
def max_total(size: str, mode: str) -> float:
    """Return the tallest stacked bar (GB) across all sequence lengths.

    Args:
        size: A model-size key of ``Y[mode]``.
        mode: A recomputation mode key of ``Y``.
    """
    per_component = Y[mode][size]
    # Summed left to right: parameters + gradients + optimizer + activations.
    stacked = sum(
        per_component[name]
        for name in ("parameters", "gradients", "optimizer", "activations")
    )
    return float(stacked.max())
|
|
|
|
|
# Upper y-limit per (mode, size): the tallest stack plus 5% headroom.
layout_y_ranges = {
    mode: dict((size, 1.05 * max_total(size, mode)) for size in model_sizes)
    for mode in recomp_modes
}
|
|
|
|
|
|
|
|
# Base layout: stacked bars on a transparent background, a horizontal legend
# above the plot, unified hover per x value, and both axes locked against
# zoom/pan.
fig.update_layout(
    barmode="stack",
    autosize=True,
    paper_bgcolor="rgba(0,0,0,0)",
    plot_bgcolor="rgba(0,0,0,0)",
    margin=dict(l=40, r=28, t=20, b=40),
    hovermode="x unified",
    legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0),
    xaxis_title_text="Sequence Length",
    xaxis_fixedrange=True,
    yaxis_title_text="Memory (GB)",
    yaxis_fixedrange=True,
)
|
|
|
|
|
|
|
|
# Model-size dropdown: one "update" button per size. Each button shows only
# that size's traces and re-fits the y-axis to them.
buttons_sizes = []
for i, size in enumerate(model_sizes):
    # Traces were added size-major with len(components) traces per size, so
    # this size owns the contiguous index range [first, first + len(components)).
    first = i * len(components)
    visible = [
        first <= trace_idx < first + len(components)
        for trace_idx in range(len(model_sizes) * len(components))
    ]
    buttons_sizes.append(dict(
        label=size,
        method="update",
        args=[
            {"visible": visible},
            # Autorange fits whatever y data is currently loaded, i.e. the
            # active recomputation mode. A fixed "none"-mode range left the
            # "selective"/"full" bars shrunken. The dotted key touches only
            # this attribute. A nested {"yaxis": {...}} object would replace
            # the axis container and drop its title and fixedrange.
            {"yaxis.autorange": True},
        ],
    ))
|
|
|
|
|
|
|
|
def y_for_mode(mode: str):
    """Return the y array of every trace, in trace order, for *mode*.

    The order matches how traces were added: model sizes outer, components
    inner. A single "update" call can then replace y for all traces at once.
    """
    return [
        Y[mode][size][comp_name]
        for size in model_sizes
        for comp_name, _color in components
    ]
|
|
|
|
|
# Recomputation dropdown: each button swaps in the y data for one mode.
buttons_recomp = []
for mode, label in [("none", "None"), ("selective", "selective"), ("full", "full")]:
    # y is replaced for ALL traces (every model size), so switching the model
    # size afterwards keeps showing the chosen mode.
    ys = y_for_mode(mode)

    buttons_recomp.append(dict(
        label=label,
        method="update",
        args=[
            {"y": ys},
            # Re-fit to the currently visible traces only. The previous fixed
            # range was the max over ALL model sizes, which squashed the
            # smaller models' bars to a sliver. The dotted key avoids
            # replacing the whole yaxis container (title, fixedrange).
            {"yaxis.autorange": True},
        ],
    ))
|
|
|
|
|
# Two dropdowns stacked to the right of the plot. Each has a caption
# annotation anchored just above its top edge: (y position, buttons, caption).
_menu_specs = [
    (0.60, buttons_sizes, "Model Size:"),
    (0.40, buttons_recomp, "Recomputation:"),
]

fig.update_layout(
    updatemenus=[
        dict(
            type="dropdown",
            x=1.03, xanchor="left",
            y=y_pos, yanchor="top",
            showactive=True,
            active=0,
            buttons=menu_buttons,
        )
        for y_pos, menu_buttons, _caption in _menu_specs
    ],
    annotations=[
        dict(
            text=caption,
            x=1.03, xanchor="left", xref="paper",
            y=y_pos, yanchor="bottom", yref="paper",
            showarrow=False,
        )
        for y_pos, _menu_buttons, caption in _menu_specs
    ],
)
|
|
|
|
|
|
|
|
# Write an embeddable HTML fragment. full_html=False omits the <html> wrapper,
# and include_plotlyjs=False leaves loading plotly.js to the host page. The
# mode bar is hidden and scroll-zoom is disabled.
_html_config = {
    "displayModeBar": False,
    "responsive": True,
    "scrollZoom": False,
}
fig.write_html(
    "./plotly-bar.html",
    include_plotlyjs=False,
    full_html=False,
    config=_html_config,
)
|
|
|
|
|
|