Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from enum import Enum | |
| from throughput_utils import create_throughput_plot | |
class AttentionType(Enum):
    """Tag distinguishing the two attention scopes by integer code.

    Not referenced anywhere else in the visible portion of this file.
    """

    # Code 0: local attention.
    LOCAL = 0
    # Code 1: global attention.
    GLOBAL = 1
class PhoneBandwidth(Enum):
    """Per-device memory bandwidth figures selectable in the UI.

    Member names populate the device dropdown (in declaration order) and the
    selected member's value is passed on as the memory bandwidth for the
    throughput plot. Units are not stated in this file -- presumably GB/s;
    TODO confirm against throughput_utils.

    NOTE: Enum members that share a value become aliases of the first one and
    disappear from iteration, so every figure here must stay distinct or the
    corresponding device will silently vanish from the dropdown.
    """

    # --- GPU figures -------------------------------------------------------
    GPU_iPhone_16 = 60
    GPU_iPhone_15 = 51.2
    GPU_iPhone_14 = 34.1

    # --- ANE figures: mobile devices ---------------------------------------
    ANE_iPad_M4_16GB = 64.34  # includes LMHEAD
    ANE_iPhone_15 = 47.18  # includes LMHEAD
    ANE_iPhone_12 = 32.09  # includes LMHEAD

    # --- ANE figures: M-series chips ----------------------------------------
    ANE_M1 = 60.87
    ANE_M1_Pro = 54.90
    ANE_M1_Max = 54.62
    ANE_M1_Ultra = 54.72
    ANE_M2 = 60.45
    ANE_M2_Max = 62.01
    ANE_M2_Ultra = 61.68
    ANE_M3_Max = 120.22
    ANE_M4_16GB = 64.18
    ANE_M4_Pro = 126.36
    ANE_M4_Max = 118.88
# Stylesheet handed to gr.Blocks(css=...) when the UI is built. The selectors
# line up with the elem_id / elem_classes values assigned to components in the
# layout:
#   #plot-container     -> the "Throughput Plot" image
#   #generate-button    -> the "Generate Throughput Plot" button
#   .sliders-container  -> the columns wrapping the GQA and MLA sliders
#   #error-status       -> the Markdown element used for validation errors
# .gradio-container is not assigned anywhere in this file; presumably it is
# the class Gradio puts on its own root wrapper -- TODO confirm.
custom_css = """
#plot-container {
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1), 0 1px 3px rgba(0, 0, 0, 0.08);
padding: 1rem;
background-color: white;
height: 100%;
margin-bottom: 1.5rem;
}
#generate-button {
background-color: #2563eb;
color: white;
border-radius: 8px;
font-weight: bold;
padding: 10px 20px;
box-shadow: 0 4px 6px rgba(37, 99, 235, 0.1);
transition: all 0.2s ease;
width: 100%;
max-width: 400px;
margin: 0 auto;
font-size: 16px;
}
#generate-button:hover {
background-color: #1d4ed8;
box-shadow: 0 6px 8px rgba(37, 99, 235, 0.2);
transform: translateY(-2px);
}
.gradio-container {
background-color: #f5f7fa;
}
/* Custom styles for sliders containers */
.sliders-container {
border: 1px solid rgba(0, 0, 0, 0.1);
border-radius: 8px;
padding: 1rem;
margin-top: 0.5rem;
background-color: rgba(255, 255, 255, 0.8);
}
#error-status {
color: #b91c1c;
background-color: #fee2e2;
border-radius: 8px;
padding: 0.75rem;
margin-top: 0.5rem;
border: 1px solid #f87171;
font-weight: 500;
}
"""
# Build the whole UI inside one Blocks context: layout first, then the
# callbacks, then the event wiring (listeners must be attached while the
# Blocks context is open).
with gr.Blocks(css=custom_css) as demo:
    # Slider handles are collected so the event wiring below can address each
    # family (GQA / MLA) as a group.
    gqa_sliders = []
    mla_sliders = []

    with gr.Column():
        # The leading emoji replaces a mis-encoded glyph; the exact original
        # character could not be recovered.
        gr.Markdown(
            "# 📱 On-Device LLM Throughput Calculator\n"
            "This tool estimates the throughput (tokens per second) of Large Language Models on devices with memory bandwidth constraints.\n"
            "It visualizes how different attention mechanisms (GQA, MLA) and context lengths affect throughput.\n"
        )

    with gr.Row():
        plot_output = gr.Image(label="Throughput Plot", type="pil", elem_id="plot-container")

    # Hidden until a validation / plotting error needs to be shown.
    status_output = gr.Markdown(visible=False, elem_id="error-status")

    with gr.Row():
        plot_button = gr.Button(
            "Generate Throughput Plot",
            size="lg",
            elem_id="generate-button",
            variant="primary",
        )

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Device Configuration")
                model_name = gr.Textbox(label="Model Name", value="TinyLLM")
                iphone_model = gr.Dropdown(
                    label="iPhone Model",
                    choices=[e.name for e in PhoneBandwidth],
                    value=PhoneBandwidth.GPU_iPhone_16.name,
                    interactive=True,
                )

            with gr.Group():
                gr.Markdown("### Attention Configurations to Plot")
                gr.Markdown("#### GQA Head Configurations")
                gr.Markdown("*Note: GQA head count must be less than or equal to the total number of heads*")
                with gr.Column(elem_classes="sliders-container"):
                    # step=1: with minimum=1 a step of 2 only reaches odd
                    # values, which makes the defaults (4, 8) and a head count
                    # equal to an even num_heads unselectable.
                    gqa_slider1 = gr.Slider(minimum=1, maximum=32, step=1, value=4,
                                            label="GQA Head Count #1")
                    gqa_slider2 = gr.Slider(minimum=1, maximum=32, step=1, value=8,
                                            label="GQA Head Count #2")
                gqa_sliders.extend([gqa_slider1, gqa_slider2])

                gr.Markdown("#### MLA Compressed Dimensions")
                gr.Markdown("*Note: MLA dimension must be less than or equal to d_model*")
                with gr.Column(elem_classes="sliders-container"):
                    mla_slider1 = gr.Slider(minimum=64, maximum=1024, step=64, value=256,
                                            label="MLA Dimension #1")
                    mla_slider2 = gr.Slider(minimum=64, maximum=1024, step=64, value=512,
                                            label="MLA Dimension #2")
                mla_sliders.extend([mla_slider1, mla_slider2])

        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Model Configuration")
                num_parameters = gr.Number(label="Parameters (Billions)", value=3)
                parameter_size = gr.Slider(minimum=1, maximum=16.0, step=1.0,
                                           label="Parameter Size (bits per param)", value=5)
                kv_parameter_size = gr.Slider(minimum=0.25, maximum=4.0, step=0.25,
                                              label="KV Cache Size (bytes per value)", value=2.0)
                num_layers = gr.Number(label="Number of Layers", value=36)
                num_heads = gr.Number(label="Number of Heads", value=16,
                                      info="GQA head counts must be less than or equal to this value")
                d_model = gr.Number(label="D Model", value=2048,
                                    info="MLA dimensions must be less than or equal to this value")

            with gr.Group():
                gr.Markdown("### Context Configuration")
                ctx_length = gr.Slider(minimum=1024, maximum=131072, step=1024,
                                       label="Max Context Length", value=65536)
                local_layers = gr.Number(label="Local Attention Layers", value=0)
                global_layers = gr.Number(label="Global Attention Layers", value=1)
                swa_size = gr.Slider(minimum=1024, maximum=32768, step=1024,
                                     label="Sliding Window Size", value=4096)

    gr.Markdown(
        "\n"
        "For more information, see [JAX ML Scaling Book](https://jax-ml.github.io/scaling-book/inference/#theoretical-estimates-for-llm-latency-and-throughput).\n"
    )

    def generate_throughput_plot(
        model_name, iphone_model, num_parameters, parameter_size,
        kv_parameter_size, num_layers, num_heads, d_model, ctx_length,
        local_layers, global_layers, swa_size, gqa_1, gqa_2, mla_1, mla_2
    ):
        """Validate the form values and render the throughput plot.

        The arguments arrive in the order of the ``inputs`` list of the
        button's click handler. ``iphone_model`` is a ``PhoneBandwidth``
        member name; ``gqa_1``/``gqa_2`` and ``mla_1``/``mla_2`` are the
        current values of the GQA and MLA sliders.

        Returns:
            A two-element list of ``gr.update`` objects for
            ``[plot_output, status_output]``: on success the plot image plus a
            hidden, empty status; on any failure a cleared plot plus a visible
            status carrying the error text. Errors are never raised to Gradio.
        """
        try:
            # Done inside the try so an unknown / empty device selection is
            # reported in the status element instead of escaping the handler.
            if iphone_model not in PhoneBandwidth.__members__:
                raise ValueError(f"Unknown device selection: {iphone_model!r}")
            memory_bandwidth = PhoneBandwidth[iphone_model].value

            if "iPhone" not in model_name:
                model_name = f"iPhone {iphone_model}: {model_name}"

            # A cleared gr.Number field is delivered as None; name the empty
            # fields rather than letting a comparison below fail with an
            # opaque TypeError.
            numeric_fields = {
                "Parameters (Billions)": num_parameters,
                "Number of Layers": num_layers,
                "Number of Heads": num_heads,
                "D Model": d_model,
                "Local Attention Layers": local_layers,
                "Global Attention Layers": global_layers,
            }
            empty_fields = [label for label, value in numeric_fields.items() if value is None]
            if empty_fields:
                raise ValueError(f"Please provide a value for: {', '.join(empty_fields)}")

            # GQA head counts must not exceed the total number of attention heads.
            for gqa_heads, label in [(gqa_1, "GQA Head Count #1"), (gqa_2, "GQA Head Count #2")]:
                if gqa_heads > num_heads:
                    raise ValueError(f"{label} ({gqa_heads}) cannot be greater than the total number of attention heads ({num_heads})")

            # MLA compressed dimensions must not exceed d_model.
            for mla_dim, label in [(mla_1, "MLA Dimension #1"), (mla_2, "MLA Dimension #2")]:
                if mla_dim > d_model:
                    raise ValueError(f"{label} ({mla_dim}) cannot be greater than the model dimension (d_model = {d_model})")

            plot_img = create_throughput_plot(
                model_name,
                memory_bandwidth,
                num_parameters,
                parameter_size,
                kv_parameter_size,
                num_layers,
                num_heads,
                d_model,
                ctx_length,
                local_layers,
                global_layers,
                swa_size,
                [gqa_1, gqa_2],
                [mla_1, mla_2],
            )
        except Exception as e:
            # UI boundary: every failure is turned into a visible message.
            err_string = f"Error generating plot: {e}"
            print(err_string)
            # Show error message, clear plot
            return [
                gr.update(value=None),
                gr.update(visible=True, value=f"⚠️ {err_string}"),
            ]

        # Hide error message, show plot
        return [
            gr.update(value=plot_img),
            gr.update(visible=False, value=""),
        ]

    def _clamp_sliders(limit, floor, sliders, current_values):
        """Build one update per slider capping its maximum and value at ``limit``.

        ``limit`` falls back to ``floor`` when it is empty/zero or below
        ``floor``. ``current_values`` are the sliders' live values; when none
        are supplied the sliders' construction-time defaults are used instead.
        """
        if not limit or limit < floor:
            limit = floor
        values = current_values or [slider.value for slider in sliders]
        return [gr.update(maximum=limit, value=min(value, limit)) for value in values]

    def update_gqa_sliders(heads_value, *current_values):
        """Cap the GQA sliders at the new head count (at least 1).

        ``current_values`` must be the sliders' live values, supplied through
        the event's ``inputs``: a component object's ``.value`` attribute only
        holds its initial default, so reading it here would reset whatever the
        user had selected.
        """
        return _clamp_sliders(heads_value, 1, gqa_sliders, current_values)

    def update_mla_sliders(d_model_value, *current_values):
        """Cap the MLA sliders at the new d_model (at least 64).

        ``current_values`` must be the sliders' live values, supplied through
        the event's ``inputs`` (see ``update_gqa_sliders`` for why).
        """
        return _clamp_sliders(d_model_value, 64, mla_sliders, current_values)

    # Re-cap the sliders whenever the model configuration changes. The sliders
    # are listed as inputs as well as outputs so the callbacks receive their
    # current values.
    num_heads.change(
        update_gqa_sliders,
        inputs=[num_heads, *gqa_sliders],
        outputs=gqa_sliders,
    )
    d_model.change(
        update_mla_sliders,
        inputs=[d_model, *mla_sliders],
        outputs=mla_sliders,
    )

    # Order here must match generate_throughput_plot's parameter order.
    plot_button.click(
        generate_throughput_plot,
        inputs=[
            model_name,
            iphone_model,
            num_parameters,
            parameter_size,
            kv_parameter_size,
            num_layers,
            num_heads,
            d_model,
            ctx_length,
            local_layers,
            global_layers,
            swa_size,
            *gqa_sliders,
            *mla_sliders,
        ],
        outputs=[plot_output, status_output],
    )

demo.launch()