import gradio as gr
import matplotlib
matplotlib.use('Agg')  # headless backend so figures can be rendered without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator

HARM_INTRO = """
The Chinchilla scaling laws focus on optimally scaling training compute, but often we also care about inference cost.
This tool follows [Harm de Vries' blog post](https://www.harmdevries.com/post/model-size-vs-compute-overhead/) and visualizes the tradeoff between training compute and inference cost (i.e. model size).
"""

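# Rough summary of the relations used below (following the blog post):
#   loss fit:  L(N, D) = E + A / N**alpha + B / D**beta
#   compute:   C ≈ 6 * N * D FLOPs for N parameters trained on D tokens
#   overhead:  scaling the optimal model by kn and the data by kd costs kn * kd
#              times the Chinchilla-optimal compute, i.e. an extra kn * kd - 1.
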
# Peak dense FLOP/s per GPU (BF16, without sparsity)
A100_flops = 312e12
H100_flops = 990e12

# Constants of the parametric loss fit L(N, D) = E + A / N**alpha + B / D**beta
E = 1.62
A = 406.4
B = 410.7
alpha = 0.336
beta = 0.283

# One billion: converts the UI inputs (billions of parameters / tokens) to absolute counts
Bn = 10**9

# Constant of the compute-optimal allocation (see n_opt / d_opt below)
G = ((alpha*A)/(beta*B))**(1/(alpha+beta))

def to_flops(N, D):
    # Standard approximation: training compute C ≈ 6 * N * D FLOPs
    return 6 * N * D

def n_opt(C):
    # Chinchilla-optimal model size for a compute budget C
    return G * ((C/6) ** (beta / (alpha+beta)))

def d_opt(C):
    # Chinchilla-optimal number of training tokens for a compute budget C
    return (1/G) * ((C/6) ** (alpha / (alpha+beta)))

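# Rough sanity check with the fit constants above: the default configuration
# (7B parameters, 2T tokens) spends C = 6 * 7e9 * 2e12 ≈ 8.4e22 FLOPs; for that
# budget n_opt(C) comes out around 17B parameters and d_opt(C) around 0.8T tokens,
# so the 7B model sits well below the Chinchilla-optimal size (kn ≈ 0.4).
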
def compute_kd(kn):
    # Given a model scaled to a fraction kn of the Chinchilla-optimal size, return the
    # fraction kd of the optimal token count needed to reach the same loss.
    frac = (A/B) * (G**(-alpha-beta))
    kd = (1 - ((kn**-alpha - 1) * frac)) ** (1/(-beta))
    return kd

def compute_overhead(kn, kd):
    # Extra training compute relative to the Chinchilla-optimal run:
    # C' = 6 * (kn*N_opt) * (kd*D_opt) = kn * kd * C, so the overhead is kn*kd - 1.
    return kn*kd - 1

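# For example, with the constants above a model at half the optimal size (kn = 0.5)
# needs roughly kd ≈ 2.42 times the optimal token count to match the loss,
# i.e. about a 21% training-compute overhead (0.5 * 2.42 - 1 ≈ 0.21).
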
# Precompute the overhead curve over a range of model-size fractions kn
kn_min = 0.18
kn_max = 2

kns = np.linspace(kn_min, kn_max, 100)
overheads = []
for kn in kns:
    kd = compute_kd(kn)
    overheads.append(compute_overhead(kn, kd)*100)

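# Plot the precomputed overhead curve and mark where the requested model sits on it,
# together with the Chinchilla-optimal point (kn = 1, 0% overhead).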
def plot_curve(kn, kd):
    fig, ax = plt.subplots(dpi=200, figsize=(5, 3))
    plt.plot(kns, overheads, color="black", zorder=1)
    plt.scatter([kn], [compute_overhead(kn, kd)*100], s=100, marker="o", c="red", label="You are here!", zorder=2)
    plt.scatter([1.0], [0.0], marker="o", s=100, c="blue", label="Chinchilla optimal", zorder=2)
    plt.xlabel("Fraction of Chinchilla optimal model size")
    plt.ylabel("Compute overhead (%)")
    plt.legend(loc="best")
    plt.grid(True, which="both")
    plt.grid(True, which="minor", alpha=0.5)
    ax.yaxis.set_minor_locator(MultipleLocator(10))
    plt.tight_layout()

    return fig

def compute(N, D, gpu_type, gpu_util, n_gpus, gpu_price):
    # Total training compute for the requested model/dataset size (inputs are in billions)
    C = to_flops(N * Bn, D * Bn)
    N_opt = n_opt(C)
    D_opt = d_opt(C)

    # Fraction of the Chinchilla-optimal model size and the matching data fraction
    kn = Bn*N/N_opt
    kd = compute_kd(kn)

    fig = plot_curve(kn, kd)

    # GPU hours needed to run the training at the given utilization
    gpu_util = gpu_util/100
    if gpu_type == "H100":
        gpu_flops = H100_flops * gpu_util
    else:
        gpu_flops = A100_flops * gpu_util
    gpu_hours = C / (gpu_flops * 3600)

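    # Back-of-the-envelope check: for the default 7B/2T configuration (C ≈ 8.4e22 FLOPs)
    # at 50% utilization of an H100's ~990 TFLOP/s peak, this is about
    # 8.4e22 / (0.5 * 990e12 * 3600) ≈ 47,000 GPU-hours.
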
    text = f"""\
## Training summary

|Training compute| Training cost | Training time | Total GPU hours |
|:----|:-------|:-------|:-------|
|{C:.2E} FLOPs | ${(gpu_hours * gpu_price)/1e6:.2f}M | {gpu_hours/(24*n_gpus):.2f} days | {gpu_hours/1_000_000:.2f}M |

## Chinchilla and Training/Inference Trade-off
Optimal model/dataset size for the training compute budget, and how the chosen model size translates into training overhead and inference savings according to Harm's law.

|Chinchilla optimal model | Chinchilla optimal dataset | Training overhead | Inference savings|
|:----|:-------|:----|:-------|
| {N_opt/Bn:.2f}B parameters | {D_opt/Bn:.2f}B tokens | {100*compute_overhead(kn, kd):.2f}% | {100 - kn*100:.2f}% |
"""

    return text, fig

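# Gradio UI: model/dataset and cluster inputs on top, the Harm's-law plot and the
# textual summary side by side below. Both the button and the initial page load call compute().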
with gr.Blocks() as demo:
    gr.Markdown("# Train LLMs")

    gr.Markdown("## Training configuration")
    with gr.Row():
        N = gr.Number(value=7, label="Model size (in B parameters):")
        D = gr.Number(value=2000, label="Dataset size (in B tokens):")

    gr.Markdown("## Cluster configuration")
    with gr.Row():
        n_gpus = gr.Number(value=1000, label="Number of GPUs")
        gpu_type = gr.Dropdown(choices=["A100", "H100"], value="H100", label="GPU type")
        gpu_util = gr.Number(value=50, label="% GPU utilization")
        gpu_price = gr.Number(value=3.00, label="$/GPU/Hour")
    button = gr.Button("Compute!")

    with gr.Row():
        with gr.Column():
            gr.Markdown("## Harm's law")
            plot = gr.Plot()  # filled in by compute() on load and on button click
            gr.Markdown(HARM_INTRO)

        with gr.Column():
            md = gr.Markdown("")

    button.click(fn=compute, inputs=[N, D, gpu_type, gpu_util, n_gpus, gpu_price], outputs=[md, plot])
    demo.load(fn=compute, inputs=[N, D, gpu_type, gpu_util, n_gpus, gpu_price], outputs=[md, plot])

demo.launch()