Spaces:
Configuration error
Configuration error
| import { useState } from "react"; | |
| import { BenchmarkChart, type BenchmarkRow } from "@/components/benchmark-chart"; | |
// Field-name prefixes used in BenchmarkRow, one per benchmarked implementation
// (e.g. "groupedMm" in "groupedMmPrefill").
type ImplementationKey =
  | "eager"
  | "groupedMm"
  | "batchedMm"
  | "groupedBatched";

// Field-name suffixes used in BenchmarkRow, one per latency metric.
type MetricKey = "Prefill" | "Decode" | "Total";

// The BenchmarkRow fields spelled `<implementation><Metric>` — i.e. every
// latency value, as opposed to the `compileGroup` label.
type NumericBenchmarkKey = Extract<
  keyof BenchmarkRow,
  `${ImplementationKey}${MetricKey}`
>;
/**
 * Converts one embedded benchmark CSV into the two rows the chart renders:
 * "no compile" (Torch Compile column is `False`) and "compiled" (any other
 * Torch Compile value, e.g. `max-autotune`).
 *
 * Only rows whose "New Tokens" column equals `newTokens` are considered. For
 * each implementation, the first matching row supplies the prefill and decode
 * latencies, and the total is their sum. All values are rounded to 2 decimals.
 * An implementation with no matching row, or with an unparsable cell,
 * contributes 0.
 *
 * Columns are read by position, following the CSV header:
 *   2 New Tokens, 3 Torch Compile, 4 Implementation,
 *   6 Mean Prefill Latency (ms), 7 Mean Decode Latency (ms)
 *
 * @param csvContent CSV text whose first line is a header. `undefined` or an
 *   empty string yields all-zero rows.
 * @param newTokens Value of the "New Tokens" column to select.
 * @returns `[noCompileRow, compiledRow]`.
 */
function parseCSVData(
  csvContent: string | undefined,
  newTokens: number
): BenchmarkRow[] {
  const COL_NEW_TOKENS = 2;
  const COL_TORCH_COMPILE = 3;
  const COL_IMPLEMENTATION = 4;
  const COL_PREFILL_MS = 6;
  const COL_DECODE_MS = 7;

  // CSV implementation id -> field-name prefix in BenchmarkRow.
  const implementations: ReadonlyArray<{ id: string; key: ImplementationKey }> = [
    { id: "eager", key: "eager" },
    { id: "grouped_mm", key: "groupedMm" },
    { id: "batched_mm", key: "batchedMm" },
    { id: "grouped_prefill+batched_decode", key: "groupedBatched" },
  ];

  // Split every data row into trimmed cells once. Matching exact cell values
  // (instead of substring-searching the whole line) keeps, e.g., a
  // "grouped_batched_mm" row from being mistaken for "batched_mm".
  const tokenCell = String(newTokens);
  const rows = (csvContent ?? "")
    .trim()
    .split(/\r?\n/)
    .slice(1) // skip header
    .filter((line) => line.trim() !== "")
    .map((line) => line.split(",").map((cell) => cell.trim()));

  // Missing or non-numeric cells count as 0, mirroring `parseFloat(x) || 0`.
  const toNumber = (cell: string | undefined) => parseFloat(cell ?? "") || 0;
  const round2 = (value: number) => Math.round(value * 100) / 100;

  const findRow = (implId: string, isCompiled: boolean) =>
    rows.find(
      (cells) =>
        cells[COL_IMPLEMENTATION] === implId &&
        cells[COL_NEW_TOKENS] === tokenCell &&
        // "False" marks the uncompiled run; any other value names a compile mode.
        (cells[COL_TORCH_COMPILE] !== "False") === isCompiled
    );

  const buildRow = (compileGroup: string, isCompiled: boolean): BenchmarkRow => {
    const row: BenchmarkRow = {
      compileGroup,
      eagerPrefill: 0,
      eagerDecode: 0,
      eagerTotal: 0,
      groupedMmPrefill: 0,
      groupedMmDecode: 0,
      groupedMmTotal: 0,
      batchedMmPrefill: 0,
      batchedMmDecode: 0,
      batchedMmTotal: 0,
      groupedBatchedPrefill: 0,
      groupedBatchedDecode: 0,
      groupedBatchedTotal: 0,
    };
    for (const { id, key } of implementations) {
      const cells = findRow(id, isCompiled);
      const prefill = toNumber(cells?.[COL_PREFILL_MS]);
      const decode = toNumber(cells?.[COL_DECODE_MS]);
      row[`${key}Prefill` as NumericBenchmarkKey] = round2(prefill);
      row[`${key}Decode` as NumericBenchmarkKey] = round2(decode);
      row[`${key}Total` as NumericBenchmarkKey] = round2(prefill + decode);
    }
    return row;
  };

  return [buildRow("no compile", false), buildRow("compiled", true)];
}
// Header line shared by every embedded dataset. parseCSVData reads New Tokens
// and the prefill/decode latency columns by index, so keep this order stable.
const CSV_HEADER =
  "Batch Size,Seq Length,New Tokens,Torch Compile,Implementation,Mean Generation Latency (ms),Mean Prefill Latency (ms),Mean Decode Latency (ms),Peak Mem (MB)";

// Embedded benchmark results keyed by `${batchSize}-${seqLength}` (the key App
// builds). Each dataset holds new-token counts 16 and 64 for every
// implementation, once uncompiled ("False") and once with a torch.compile mode.
const csvData: Record<string, string> = {
  "1-16": `${CSV_HEADER}
1,16,16,False,eager,1193.70,269.94,923.76,27333.27
1,16,16,max-autotune-no-cudagraphs,eager,1332.96,269.72,1063.24,27333.27
1,16,16,False,grouped_mm,814.61,64.32,750.29,27333.27
1,16,16,max-autotune-no-cudagraphs,grouped_mm,476.82,63.42,413.41,27333.27
1,16,16,False,batched_mm,535.34,52.18,483.17,28386.68
1,16,16,max-autotune,batched_mm,144.21,52.25,91.97,28386.68
1,16,16,False,grouped_prefill+batched_decode,579.92,64.83,515.10,27396.05
1,16,16,max-autotune,grouped_prefill+batched_decode,162.56,64.71,97.84,27332.98
1,16,64,False,eager,4303.03,274.71,4028.32,27342.27
1,16,64,max-autotune-no-cudagraphs,eager,4908.60,288.97,4619.63,27343.43
1,16,64,False,grouped_mm,3300.09,64.67,3235.42,27343.43
1,16,64,max-autotune-no-cudagraphs,grouped_mm,1801.67,64.25,1737.41,27342.27
1,16,64,False,batched_mm,2151.96,52.23,2099.73,28395.68
1,16,64,max-autotune,batched_mm,434.84,52.25,382.58,28395.68
1,16,64,False,grouped_prefill+batched_decode,2217.35,64.99,2152.36,27405.06
1,16,64,max-autotune,grouped_prefill+batched_decode,465.11,64.50,400.62,27341.98`,
  "1-128": `${CSV_HEADER}
1,128,16,False,eager,1452.00,480.20,971.81,27425.16
1,128,16,max-autotune-no-cudagraphs,eager,1620.98,498.00,1122.99,27410.34
1,128,16,False,grouped_mm,850.87,76.51,774.35,27425.40
1,128,16,max-autotune-no-cudagraphs,grouped_mm,492.87,76.79,416.08,27425.40
1,128,16,False,batched_mm,815.11,316.56,498.56,35866.47
1,128,16,max-autotune,batched_mm,412.98,316.33,96.65,35866.48
1,128,16,False,grouped_prefill+batched_decode,588.87,77.15,511.72,27470.51
1,128,16,max-autotune,grouped_prefill+batched_decode,181.79,76.78,105.01,27424.49
1,128,64,False,eager,4524.16,486.84,4037.31,27418.44
1,128,64,max-autotune-no-cudagraphs,eager,5116.71,477.42,4639.29,27419.62
1,128,64,False,grouped_mm,3327.45,76.46,3250.98,27434.67
1,128,64,max-autotune-no-cudagraphs,grouped_mm,1824.70,76.55,1748.15,27433.49
1,128,64,False,batched_mm,2411.23,316.18,2095.05,35875.48
1,128,64,max-autotune,batched_mm,707.73,316.24,391.49,35875.48
1,128,64,False,grouped_prefill+batched_decode,2219.34,76.89,2142.45,27479.51
1,128,64,max-autotune,grouped_prefill+batched_decode,489.06,76.88,412.18,27433.50`,
  "4-16": `${CSV_HEADER}
4,16,16,False,eager,2412.69,432.96,1979.74,27420.04
4,16,16,max-autotune-no-cudagraphs,eager,2899.32,428.52,2470.80,27384.22
4,16,16,False,grouped_mm,923.69,74.45,849.24,27384.22
4,16,16,max-autotune-no-cudagraphs,grouped_mm,593.89,75.93,517.97,27384.22
4,16,16,False,batched_mm,673.48,164.33,509.16,31601.36
4,16,16,max-autotune,batched_mm,330.13,164.50,165.63,31601.36
4,16,16,False,grouped_prefill+batched_decode,599.51,74.42,525.09,27638.76
4,16,16,max-autotune,grouped_prefill+batched_decode,240.94,74.37,166.57,27382.55
4,16,64,False,eager,9249.34,429.71,8819.63,27448.23
4,16,64,max-autotune-no-cudagraphs,eager,11396.42,428.78,10967.64,27429.62
4,16,64,False,grouped_mm,3649.18,74.14,3575.05,27429.62
4,16,64,max-autotune-no-cudagraphs,grouped_mm,2264.07,74.28,2189.79,27424.98
4,16,64,False,batched_mm,2309.70,164.35,2145.35,31642.11
4,16,64,max-autotune,batched_mm,846.58,164.50,682.08,31642.11
4,16,64,False,grouped_prefill+batched_decode,2270.82,74.18,2196.64,27679.51
4,16,64,max-autotune,grouped_prefill+batched_decode,776.50,73.13,703.36,27423.29`,
  "4-128": `${CSV_HEADER}
4,128,16,False,eager,2147.71,523.33,1624.38,27691.78
4,128,16,max-autotune-no-cudagraphs,eager,2480.61,519.22,1961.39,27696.76
4,128,16,False,grouped_mm,908.79,80.24,828.55,27756.97
4,128,16,max-autotune-no-cudagraphs,grouped_mm,577.16,79.94,497.22,27751.99
4,128,16,False,batched_mm,1744.98,1235.64,509.34,61519.94
4,128,16,max-autotune,batched_mm,1404.29,1236.12,168.17,61519.94
4,128,16,False,grouped_prefill+batched_decode,603.57,80.56,523.01,27936.07
4,128,16,max-autotune,grouped_prefill+batched_decode,249.57,79.61,169.96,27751.99
4,128,64,False,eager,7273.57,520.92,6752.65,27727.80
4,128,64,max-autotune-no-cudagraphs,eager,8629.54,525.01,8104.53,27732.78
4,128,64,False,grouped_mm,3531.64,80.73,3450.91,27792.99
4,128,64,max-autotune-no-cudagraphs,grouped_mm,2136.08,80.10,2055.98,27792.99
4,128,64,False,batched_mm,3365.31,1235.87,2129.44,61560.94
4,128,64,max-autotune,batched_mm,1931.15,1236.34,694.81,61555.96
4,128,64,False,grouped_prefill+batched_decode,2264.46,80.26,2184.20,27972.08
4,128,64,max-autotune,grouped_prefill+batched_decode,785.26,79.97,705.28,27788.02`,
};
/** Tailwind classes shared by the inline toggle buttons in the heading. */
const TOGGLE_BUTTON_CLASS =
  "underline decoration-slate-300 underline-offset-4 hover:decoration-slate-900 transition-colors";

/**
 * Legend swatches, one row per latency phase. Hoisted so it is not rebuilt on
 * every render. NOTE(review): four colors per phase, presumably one per
 * implementation as drawn by BenchmarkChart — keep in sync with that component.
 */
const LEGEND_ITEMS = [
  { label: "prefill", colors: ["#2563eb", "#16a34a", "#7c3aed", "#ea580c"] },
  { label: "decode", colors: ["#93c5fd", "#86efac", "#d8b4fe", "#fdba74"] },
];

/** Inline underlined button that shows a numeric setting and flips it on click. */
function ToggleValue({
  label,
  value,
  onToggle,
}: {
  label: string;
  value: number;
  onToggle: () => void;
}) {
  return (
    <button
      type="button"
      onClick={onToggle}
      aria-label={`Toggle ${label} (currently ${value})`}
      className={TOGGLE_BUTTON_CLASS}
    >
      {value}
    </button>
  );
}

/**
 * Benchmark page. The heading lets the user flip batch size (1/4), sequence
 * length (16/128) and new-token count (16/64). The chart then shows per-
 * implementation latencies with and without compilation, next to a headline
 * speedup of eager-uncompiled vs. the chosen compiled implementation.
 */
function App() {
  const [batchSize, setBatchSize] = useState(4);
  const [seqLength, setSeqLength] = useState(16);
  const [newTokens, setNewTokens] = useState(64);

  const key = `${batchSize}-${seqLength}`;
  // A combination without embedded data yields an empty CSV (all-zero chart)
  // instead of passing `undefined` into parseCSVData, which calls `.trim()`.
  const dataTokens = parseCSVData(csvData[key] ?? "", newTokens);

  // Headline comparison target; in the embedded data this is the fastest
  // compiled implementation for each configuration.
  const speedupTarget =
    batchSize === 1 && seqLength === 16 ? "batchedMm" : "groupedBatched";

  const noCompileRow = dataTokens.find((row) => row.compileGroup === "no compile");
  const compiledRow = dataTokens.find((row) => row.compileGroup === "compiled");

  const eagerNoCompile = noCompileRow?.eagerTotal;
  const compiledTarget =
    speedupTarget === "batchedMm"
      ? compiledRow?.batchedMmTotal
      : compiledRow?.groupedBatchedTotal;

  // The truthiness check also rejects 0 (missing data), so we never divide by
  // zero or show a meaningless "0.0x".
  const speedup =
    eagerNoCompile && compiledTarget ? eagerNoCompile / compiledTarget : undefined;
  const speedupLabel = speedup ? `${speedup.toFixed(1)}x` : "—";
  const speedupText =
    speedupTarget === "batchedMm"
      ? "eager no compile → batched_mm compiled"
      : "eager no compile → grouped+batched compiled";

  return (
    <div className="min-h-screen bg-white py-12 px-4">
      <div className="max-w-7xl mx-auto">
        <section>
          <div className="mb-6 text-center">
            <h2 className="text-2xl font-semibold text-slate-900">
              batch size{" "}
              <ToggleValue
                label="batch size"
                value={batchSize}
                onToggle={() => setBatchSize((prev) => (prev === 1 ? 4 : 1))}
              />
              , sequence length{" "}
              <ToggleValue
                label="sequence length"
                value={seqLength}
                onToggle={() => setSeqLength((prev) => (prev === 16 ? 128 : 16))}
              />
              , new tokens{" "}
              <ToggleValue
                label="new tokens"
                value={newTokens}
                onToggle={() => setNewTokens((prev) => (prev === 16 ? 64 : 16))}
              />
            </h2>
            <div className="mt-2 flex flex-wrap items-center justify-center gap-2">
              <span className="text-xs uppercase tracking-wide text-slate-500">
                {speedupText}
              </span>
              <span className="text-lg font-semibold text-slate-900">
                {speedupLabel}
              </span>
            </div>
          </div>
          <div className="max-w-3xl mx-auto">
            <BenchmarkChart data={dataTokens} />
          </div>
        </section>
        <div className="flex flex-wrap items-center justify-center gap-x-6 gap-y-2 mt-8">
          {LEGEND_ITEMS.map((item) => (
            <div key={item.label} className="flex items-center gap-2">
              <div className="flex items-center gap-1">
                {item.colors.map((color) => (
                  <div
                    key={color}
                    className="h-3 w-3 rounded-sm"
                    style={{ backgroundColor: color }}
                  />
                ))}
              </div>
              <span className="text-sm text-slate-600">{item.label}</span>
            </div>
          ))}
        </div>
      </div>
    </div>
  );
}

export default App;