moe-experts / src /App.tsx
stevhliu's picture
stevhliu HF Staff
final updates
0158b7b
import { useState } from "react";
import { BenchmarkChart, type BenchmarkRow } from "@/components/benchmark-chart";
// Helper types used when parsing the CSV data into the chart-friendly
// BenchmarkRow format.
// ImplementationKey: camelCase prefix of the BenchmarkRow fields that belong to
// one benchmarked implementation (e.g. "groupedMm" -> groupedMmPrefill).
type ImplementationKey = "eager" | "groupedMm" | "batchedMm" | "groupedBatched";
// MetricKey: suffix of those fields, one per latency metric stored for an
// implementation.
type MetricKey = "Prefill" | "Decode" | "Total";
// NumericBenchmarkKey: the keys of BenchmarkRow that have the form
// `${ImplementationKey}${MetricKey}` (e.g. "eagerPrefill"). parseCSVData uses
// it to index a row with a computed key when assigning the parsed latencies.
type NumericBenchmarkKey = Extract<
keyof BenchmarkRow,
`${ImplementationKey}${MetricKey}`
>;
/**
 * Converts one benchmark CSV into the two rows the chart displays:
 * "no compile" (Torch Compile column is exactly `False`) and "compiled"
 * (Torch Compile column is a `max-autotune*` mode).
 *
 * For each implementation, the first data row whose New Tokens, Torch Compile
 * and Implementation columns match is used. Its mean prefill and mean decode
 * latencies are rounded to two decimals, and the total is the rounded sum of
 * those two values (the CSV's own "Mean Generation Latency" column is not
 * read). Implementations without a matching row, or with unparsable numbers,
 * keep the value 0.
 *
 * Columns are compared field-by-field (after trimming) rather than by
 * searching the raw line, so an implementation id that is a substring of
 * another id, or of some other column, can never select the wrong row.
 *
 * @param csvContent - CSV text whose first line is a header and whose columns
 *   are: Batch Size, Seq Length, New Tokens, Torch Compile, Implementation,
 *   Mean Generation Latency, Mean Prefill Latency, Mean Decode Latency,
 *   Peak Mem. A missing or empty value yields two all-zero rows.
 * @param newTokens - Only rows whose New Tokens column equals this value are
 *   considered.
 * @returns `[noCompileRow, compiledRow]`, in that order.
 */
function parseCSVData(
  csvContent: string,
  newTokens: number
): BenchmarkRow[] {
  // Zero-based column positions, following the CSV header.
  const column = {
    newTokens: 2,
    compile: 3,
    implementation: 4,
    prefill: 6,
    decode: 7,
  } as const;

  // `id` is the value of the Implementation column; `key` is the matching
  // BenchmarkRow field prefix.
  const implementations: Array<{ id: string; key: ImplementationKey }> = [
    { id: "eager", key: "eager" },
    { id: "grouped_mm", key: "groupedMm" },
    { id: "batched_mm", key: "batchedMm" },
    { id: "grouped_prefill+batched_decode", key: "groupedBatched" },
  ];

  // Split every data line into trimmed fields once, up front. The `?? ""`
  // guards against a dataset lookup that produced `undefined` at runtime;
  // `\r?\n` tolerates Windows line endings. The header line is dropped.
  const records = (csvContent ?? "")
    .trim()
    .split(/\r?\n/)
    .slice(1)
    .map((line) => line.split(",").map((field) => field.trim()));

  // Missing or non-numeric fields count as 0 so the chart still renders.
  const toLatency = (field: string | undefined) =>
    parseFloat(field ?? "") || 0;
  const round2 = (value: number) => Math.round(value * 100) / 100;
  const wantedTokens = String(newTokens);

  const findRecord = (implId: string, isCompiled: boolean) =>
    records.find((fields) => {
      const compileMode = fields[column.compile] ?? "";
      // Covers both "max-autotune" and "max-autotune-no-cudagraphs".
      const matchesCompile = isCompiled
        ? compileMode.startsWith("max-autotune")
        : compileMode === "False";
      return (
        matchesCompile &&
        fields[column.implementation] === implId &&
        fields[column.newTokens] === wantedTokens
      );
    });

  const buildRow = (compileGroup: string, isCompiled: boolean): BenchmarkRow => {
    const row: BenchmarkRow = {
      compileGroup,
      eagerPrefill: 0,
      eagerDecode: 0,
      eagerTotal: 0,
      groupedMmPrefill: 0,
      groupedMmDecode: 0,
      groupedMmTotal: 0,
      batchedMmPrefill: 0,
      batchedMmDecode: 0,
      batchedMmTotal: 0,
      groupedBatchedPrefill: 0,
      groupedBatchedDecode: 0,
      groupedBatchedTotal: 0,
    };
    for (const impl of implementations) {
      const record = findRecord(impl.id, isCompiled);
      const prefill = toLatency(record?.[column.prefill]);
      const decode = toLatency(record?.[column.decode]);
      const prefillKey = `${impl.key}Prefill` as NumericBenchmarkKey;
      const decodeKey = `${impl.key}Decode` as NumericBenchmarkKey;
      const totalKey = `${impl.key}Total` as NumericBenchmarkKey;
      row[prefillKey] = round2(prefill);
      row[decodeKey] = round2(decode);
      // Total is derived from the unrounded parts, then rounded once.
      row[totalKey] = round2(prefill + decode);
    }
    return row;
  };

  return [buildRow("no compile", false), buildRow("compiled", true)];
}
// Benchmark results embedded as raw CSV text, one CSV per configuration.
// Keys have the form `${batchSize}-${seqLength}` (the same values as the first
// two CSV columns, and the lookup key built in App). Every CSV starts with a
// header line and then holds one row per combination of new-token count,
// torch compile mode and implementation. Per the header, latencies are in
// milliseconds and peak memory is in MB.
const csvData: Record<string, string> = {
"1-16": `Batch Size,Seq Length,New Tokens,Torch Compile,Implementation,Mean Generation Latency (ms),Mean Prefill Latency (ms),Mean Decode Latency (ms),Peak Mem (MB)
1,16,16,False,eager,1193.70,269.94,923.76,27333.27
1,16,16,max-autotune-no-cudagraphs,eager,1332.96,269.72,1063.24,27333.27
1,16,16,False,grouped_mm,814.61,64.32,750.29,27333.27
1,16,16,max-autotune-no-cudagraphs,grouped_mm,476.82,63.42,413.41,27333.27
1,16,16,False,batched_mm,535.34,52.18,483.17,28386.68
1,16,16,max-autotune,batched_mm,144.21,52.25,91.97,28386.68
1,16,16,False,grouped_prefill+batched_decode,579.92,64.83,515.10,27396.05
1,16,16,max-autotune,grouped_prefill+batched_decode,162.56,64.71,97.84,27332.98
1,16,64,False,eager,4303.03,274.71,4028.32,27342.27
1,16,64,max-autotune-no-cudagraphs,eager,4908.60,288.97,4619.63,27343.43
1,16,64,False,grouped_mm,3300.09,64.67,3235.42,27343.43
1,16,64,max-autotune-no-cudagraphs,grouped_mm,1801.67,64.25,1737.41,27342.27
1,16,64,False,batched_mm,2151.96,52.23,2099.73,28395.68
1,16,64,max-autotune,batched_mm,434.84,52.25,382.58,28395.68
1,16,64,False,grouped_prefill+batched_decode,2217.35,64.99,2152.36,27405.06
1,16,64,max-autotune,grouped_prefill+batched_decode,465.11,64.50,400.62,27341.98`,
"1-128": `Batch Size,Seq Length,New Tokens,Torch Compile,Implementation,Mean Generation Latency (ms),Mean Prefill Latency (ms),Mean Decode Latency (ms),Peak Mem (MB)
1,128,16,False,eager,1452.00,480.20,971.81,27425.16
1,128,16,max-autotune-no-cudagraphs,eager,1620.98,498.00,1122.99,27410.34
1,128,16,False,grouped_mm,850.87,76.51,774.35,27425.40
1,128,16,max-autotune-no-cudagraphs,grouped_mm,492.87,76.79,416.08,27425.40
1,128,16,False,batched_mm,815.11,316.56,498.56,35866.47
1,128,16,max-autotune,batched_mm,412.98,316.33,96.65,35866.48
1,128,16,False,grouped_prefill+batched_decode,588.87,77.15,511.72,27470.51
1,128,16,max-autotune,grouped_prefill+batched_decode,181.79,76.78,105.01,27424.49
1,128,64,False,eager,4524.16,486.84,4037.31,27418.44
1,128,64,max-autotune-no-cudagraphs,eager,5116.71,477.42,4639.29,27419.62
1,128,64,False,grouped_mm,3327.45,76.46,3250.98,27434.67
1,128,64,max-autotune-no-cudagraphs,grouped_mm,1824.70,76.55,1748.15,27433.49
1,128,64,False,batched_mm,2411.23,316.18,2095.05,35875.48
1,128,64,max-autotune,batched_mm,707.73,316.24,391.49,35875.48
1,128,64,False,grouped_prefill+batched_decode,2219.34,76.89,2142.45,27479.51
1,128,64,max-autotune,grouped_prefill+batched_decode,489.06,76.88,412.18,27433.50`,
"4-16": `Batch Size,Seq Length,New Tokens,Torch Compile,Implementation,Mean Generation Latency (ms),Mean Prefill Latency (ms),Mean Decode Latency (ms),Peak Mem (MB)
4,16,16,False,eager,2412.69,432.96,1979.74,27420.04
4,16,16,max-autotune-no-cudagraphs,eager,2899.32,428.52,2470.80,27384.22
4,16,16,False,grouped_mm,923.69,74.45,849.24,27384.22
4,16,16,max-autotune-no-cudagraphs,grouped_mm,593.89,75.93,517.97,27384.22
4,16,16,False,batched_mm,673.48,164.33,509.16,31601.36
4,16,16,max-autotune,batched_mm,330.13,164.50,165.63,31601.36
4,16,16,False,grouped_prefill+batched_decode,599.51,74.42,525.09,27638.76
4,16,16,max-autotune,grouped_prefill+batched_decode,240.94,74.37,166.57,27382.55
4,16,64,False,eager,9249.34,429.71,8819.63,27448.23
4,16,64,max-autotune-no-cudagraphs,eager,11396.42,428.78,10967.64,27429.62
4,16,64,False,grouped_mm,3649.18,74.14,3575.05,27429.62
4,16,64,max-autotune-no-cudagraphs,grouped_mm,2264.07,74.28,2189.79,27424.98
4,16,64,False,batched_mm,2309.70,164.35,2145.35,31642.11
4,16,64,max-autotune,batched_mm,846.58,164.50,682.08,31642.11
4,16,64,False,grouped_prefill+batched_decode,2270.82,74.18,2196.64,27679.51
4,16,64,max-autotune,grouped_prefill+batched_decode,776.50,73.13,703.36,27423.29`,
"4-128": `Batch Size,Seq Length,New Tokens,Torch Compile,Implementation,Mean Generation Latency (ms),Mean Prefill Latency (ms),Mean Decode Latency (ms),Peak Mem (MB)
4,128,16,False,eager,2147.71,523.33,1624.38,27691.78
4,128,16,max-autotune-no-cudagraphs,eager,2480.61,519.22,1961.39,27696.76
4,128,16,False,grouped_mm,908.79,80.24,828.55,27756.97
4,128,16,max-autotune-no-cudagraphs,grouped_mm,577.16,79.94,497.22,27751.99
4,128,16,False,batched_mm,1744.98,1235.64,509.34,61519.94
4,128,16,max-autotune,batched_mm,1404.29,1236.12,168.17,61519.94
4,128,16,False,grouped_prefill+batched_decode,603.57,80.56,523.01,27936.07
4,128,16,max-autotune,grouped_prefill+batched_decode,249.57,79.61,169.96,27751.99
4,128,64,False,eager,7273.57,520.92,6752.65,27727.80
4,128,64,max-autotune-no-cudagraphs,eager,8629.54,525.01,8104.53,27732.78
4,128,64,False,grouped_mm,3531.64,80.73,3450.91,27792.99
4,128,64,max-autotune-no-cudagraphs,grouped_mm,2136.08,80.10,2055.98,27792.99
4,128,64,False,batched_mm,3365.31,1235.87,2129.44,61560.94
4,128,64,max-autotune,batched_mm,1931.15,1236.34,694.81,61555.96
4,128,64,False,grouped_prefill+batched_decode,2264.46,80.26,2184.20,27972.08
4,128,64,max-autotune,grouped_prefill+batched_decode,785.26,79.97,705.28,27788.02`,
};
// Root component: an interactive benchmark view. The heading contains three
// inline toggle buttons (batch size 1/4, sequence length 16/128, new tokens
// 16/64); the selected configuration picks the embedded CSV, which is parsed
// on every render and passed to BenchmarkChart. Below the heading a headline
// speedup is shown: eager without compilation versus the compiled target
// implementation.
function App() {
  const [batchSize, setBatchSize] = useState(4);
  const [seqLength, setSeqLength] = useState(16);
  const [newTokens, setNewTokens] = useState(64);

  // Each toggle flips between the two benchmarked values of its setting.
  const toggleBatchSize = () => setBatchSize(batchSize === 1 ? 4 : 1);
  const toggleSeqLength = () => setSeqLength(seqLength === 16 ? 128 : 16);
  const toggleNewTokens = () => setNewTokens(newTokens === 16 ? 64 : 16);

  // Styling shared by the three toggle buttons.
  const toggleClassName =
    "underline decoration-slate-300 underline-offset-4 hover:decoration-slate-900 transition-colors";

  // Datasets are keyed by "<batchSize>-<seqLength>".
  const chartRows = parseCSVData(csvData[`${batchSize}-${seqLength}`], newTokens);
  const uncompiledRow = chartRows.find((r) => r.compileGroup === "no compile");
  const compiledRow = chartRows.find((r) => r.compileGroup === "compiled");

  // The smallest configuration compares against batched_mm; every other
  // configuration compares against the grouped-prefill + batched-decode mix.
  const targetsBatchedMm = batchSize === 1 && seqLength === 16;
  const baselineTotal = uncompiledRow?.eagerTotal;
  const targetTotal = targetsBatchedMm
    ? compiledRow?.batchedMmTotal
    : compiledRow?.groupedBatchedTotal;

  // Fall back to a dash whenever either total is missing/zero or the ratio is
  // not a usable number.
  let speedupLabel = "—";
  if (baselineTotal && targetTotal) {
    const ratio = baselineTotal / targetTotal;
    if (ratio) {
      speedupLabel = `${ratio.toFixed(1)}x`;
    }
  }

  let speedupText = "eager no compile → grouped+batched compiled";
  if (targetsBatchedMm) {
    speedupText = "eager no compile → batched_mm compiled";
  }

  // Legend swatches: one colour per implementation, dark shades for prefill
  // and light shades for decode.
  const prefillColors = ["#2563eb", "#16a34a", "#7c3aed", "#ea580c"];
  const decodeColors = ["#93c5fd", "#86efac", "#d8b4fe", "#fdba74"];
  const legendEntries = [
    { label: "prefill", colors: prefillColors },
    { label: "decode", colors: decodeColors },
  ];

  return (
    <div className="min-h-screen bg-white py-12 px-4">
      <div className="max-w-7xl mx-auto">
        <section>
          <div className="mb-6 text-center">
            <h2 className="text-2xl font-semibold text-slate-900">
              batch size{" "}
              <button onClick={toggleBatchSize} className={toggleClassName}>
                {batchSize}
              </button>
              , sequence length{" "}
              <button onClick={toggleSeqLength} className={toggleClassName}>
                {seqLength}
              </button>
              , new tokens{" "}
              <button onClick={toggleNewTokens} className={toggleClassName}>
                {newTokens}
              </button>
            </h2>
            <div className="mt-2 flex flex-wrap items-center justify-center gap-2">
              <span className="text-xs uppercase tracking-wide text-slate-500">
                {speedupText}
              </span>
              <span className="text-lg font-semibold text-slate-900">
                {speedupLabel}
              </span>
            </div>
          </div>
          <div className="max-w-3xl mx-auto">
            <BenchmarkChart data={chartRows} />
          </div>
        </section>
        <div className="flex flex-wrap items-center justify-center gap-x-6 gap-y-2 mt-8">
          {legendEntries.map(({ label, colors }) => (
            <div key={label} className="flex items-center gap-2">
              <div className="flex items-center gap-1">
                {colors.map((swatch) => (
                  <div
                    key={swatch}
                    className="h-3 w-3 rounded-sm"
                    style={{ backgroundColor: swatch }}
                  />
                ))}
              </div>
              <span className="text-sm text-slate-600">{label}</span>
            </div>
          ))}
        </div>
      </div>
    </div>
  );
}
export default App;