// illustrated-cluster/src/lib/workbench.ts
// WIP HF Space snapshot published by joeddav (commit 1f77aa7).
import {
a100_80gb,
b300,
gb200,
h100_sxm,
llama31_405B,
olmo3_32B,
trinityLarge400B,
type ClusterConfig,
type GPUSpec,
type ModelConfig,
type ParallelismConfig,
type TrainingConfig,
} from './trainingClusterModel'
/** Curated example models available in the workbench. */
export type ExamplePresetId = 'olmo3-32b' | 'llama31-405b' | 'trinity-large-400b'
/** Training phase of an example: standard pretraining or long-context extension. */
export type ExamplePhaseId = 'pretraining' | 'long-context'
/** GPU hardware presets selectable in the workbench. */
export type GpuPresetId = 'a100-80gb' | 'h100-sxm' | 'b300' | 'gb200'
/** Named scenarios resolvable via getScenarioWorkbenchConfig. */
export type WorkbenchScenarioId =
  | 'default'
  | 'olmo-pretraining'
  | 'olmo-long-context'
  | 'llama-pretraining'
  | 'llama-long-context'
  | 'trinity-pretraining'
  | 'trinity-long-context'
  | 'infeasible-memory'
/**
 * Fully resolved workbench state: which example/phase it derives from,
 * whether the user has edited it, and the concrete model, training, cluster,
 * and parallelism settings.
 */
export type WorkbenchConfig = {
  examplePresetId: ExamplePresetId
  phaseId: ExamplePhaseId
  // True once the config diverges from its preset (e.g. set by applyGpuPreset).
  customized: boolean
  model: ModelConfig
  training: TrainingConfig
  cluster: ClusterConfig
  parallelism: ParallelismConfig
}
/** Per-phase settings stored inside an example preset. */
type ExamplePhaseConfig = {
  cluster: ClusterConfig
  training: TrainingConfig
  parallelism: ParallelismConfig
}
/** An example preset: display label, model factory, and its phases. */
type ExamplePreset = {
  label: string
  // Factory (not a shared object) so each config gets a fresh ModelConfig.
  model: () => ModelConfig
  phases: Record<ExamplePhaseId, ExamplePhaseConfig>
}
// Selectable GPU hardware presets. `spec` is a factory so each caller
// receives a fresh GPUSpec object rather than a shared mutable instance.
const GPU_PRESETS: Record<GpuPresetId, { label: string; spec: () => GPUSpec }> = {
  'a100-80gb': {
    label: 'A100 80GB',
    spec: a100_80gb,
  },
  'h100-sxm': {
    label: 'H100 SXM',
    spec: h100_sxm,
  },
  b300: {
    label: 'B300',
    spec: b300,
  },
  gb200: {
    label: 'GB200',
    spec: gb200,
  },
}
// Field-by-field equality on the GPU characteristics that identify a preset;
// any other GPUSpec fields are ignored for matching purposes.
const gpuPresetMatches = (candidate: GPUSpec, preset: GPUSpec) => {
  const identityFields = ['name', 'hbmCapacityGB', 'peakTFLOPsBF16', 'memBandwidthTBs'] as const
  return identityFields.every((field) => candidate[field] === preset[field])
}
/**
 * Build a ClusterConfig of H100 SXM hosts: 8 GPUs per node, 900 GB/s
 * intra-node and 50 GB/s inter-node bandwidth, grouped into racks.
 */
const h100Cluster = (numNodes: number, nodesPerRack: number): ClusterConfig => {
  const cluster: ClusterConfig = {
    gpuType: h100_sxm(),
    gpusPerNode: 8,
    numNodes,
    intraNodeBandwidthGBs: 900,
    interNodeBandwidthGBs: 50,
    nodesPerRack,
    rackLabel: 'rack',
    nodeLabel: 'GPU host',
    podLabel: 'rack',
  }
  return cluster
}
/**
 * Build a ClusterConfig of B300 hosts: 8 GPUs per node, 900 GB/s intra-node
 * and 50 GB/s inter-node bandwidth, grouped into racks.
 */
const b300Cluster = (numNodes: number, nodesPerRack: number): ClusterConfig => {
  const cluster: ClusterConfig = {
    gpuType: b300(),
    gpusPerNode: 8,
    numNodes,
    intraNodeBandwidthGBs: 900,
    interNodeBandwidthGBs: 50,
    nodesPerRack,
    rackLabel: 'rack',
    nodeLabel: 'GPU host',
    podLabel: 'rack',
  }
  return cluster
}
// Curated example configurations. Each preset pairs a model factory with a
// pretraining and a long-context phase, each carrying its own cluster,
// training, and parallelism settings. These objects are the canonical copies;
// createWorkbenchConfig copies from them before handing configs to callers.
export const EXAMPLE_PRESETS: Record<ExamplePresetId, ExamplePreset> = {
  'olmo3-32b': {
    label: 'OLMo 3 32B',
    model: olmo3_32B,
    phases: {
      // 128 nodes x 8 GPUs = 1024 H100s; no tensor/pipeline parallelism,
      // sharding handled via FSDP groups of 256 at ZeRO stage 3.
      pretraining: {
        cluster: h100Cluster(128, 16),
        training: {
          microBatchSize: 1,
          seqLength: 8192,
          gradAccumSteps: 1,
          precision: 'bf16',
          activationCheckpointing: true,
          optimizer: 'adamw',
        },
        parallelism: {
          tp: 1,
          pp: 1,
          cp: 1,
          ep: 1,
          distributedOptimizer: true,
          fsdpShardGroupSize: 256,
          zeroStage: 3,
        },
      },
      // Smaller cluster (32 nodes) but 64K tokens per sequence, split across
      // 8 context-parallel ranks.
      'long-context': {
        cluster: h100Cluster(32, 8),
        training: {
          microBatchSize: 1,
          seqLength: 65536,
          gradAccumSteps: 1,
          precision: 'bf16',
          activationCheckpointing: true,
          optimizer: 'adamw',
        },
        parallelism: {
          tp: 1,
          pp: 1,
          cp: 8,
          ep: 1,
          distributedOptimizer: true,
          fsdpShardGroupSize: 256,
          zeroStage: 3,
        },
      },
    },
  },
  'llama31-405b': {
    label: 'Llama 3.1 405B',
    model: llama31_405B,
    phases: {
      // 2048 nodes x 8 GPUs = 16384 H100s; 3D parallelism (tp=8, pp=16)
      // with ZeRO stage 1 and no FSDP sharding.
      pretraining: {
        cluster: h100Cluster(2048, 16),
        training: {
          microBatchSize: 1,
          seqLength: 8192,
          gradAccumSteps: 16,
          precision: 'bf16',
          activationCheckpointing: true,
          optimizer: 'adamw',
        },
        parallelism: {
          tp: 8,
          pp: 16,
          cp: 1,
          ep: 1,
          distributedOptimizer: true,
          fsdpShardGroupSize: 0,
          zeroStage: 1,
        },
      },
      // Same cluster, 128K-token sequences split over 16 context-parallel ranks.
      'long-context': {
        cluster: h100Cluster(2048, 16),
        training: {
          microBatchSize: 1,
          seqLength: 131072,
          gradAccumSteps: 1,
          precision: 'bf16',
          activationCheckpointing: true,
          optimizer: 'adamw',
        },
        parallelism: {
          tp: 8,
          pp: 16,
          cp: 16,
          ep: 1,
          distributedOptimizer: true,
          fsdpShardGroupSize: 0,
          zeroStage: 1,
        },
      },
    },
  },
  'trinity-large-400b': {
    label: 'Trinity Large 400B',
    model: trinityLarge400B,
    phases: {
      // 256 B300 nodes (2048 GPUs), muon optimizer, expert parallelism of 8
      // with FSDP groups of 128 at ZeRO stage 3.
      pretraining: {
        cluster: b300Cluster(256, 9),
        training: {
          microBatchSize: 1,
          seqLength: 8192,
          gradAccumSteps: 8,
          precision: 'bf16',
          activationCheckpointing: true,
          optimizer: 'muon',
        },
        parallelism: {
          tp: 1,
          pp: 1,
          cp: 1,
          ep: 8,
          distributedOptimizer: true,
          fsdpShardGroupSize: 128,
          zeroStage: 3,
        },
      },
      // Same cluster, 256K-token sequences with cp=4 alongside ep=8.
      'long-context': {
        cluster: b300Cluster(256, 9),
        training: {
          microBatchSize: 1,
          seqLength: 262144,
          gradAccumSteps: 1,
          precision: 'bf16',
          activationCheckpointing: true,
          optimizer: 'muon',
        },
        parallelism: {
          tp: 1,
          pp: 1,
          cp: 4,
          ep: 8,
          distributedOptimizer: true,
          fsdpShardGroupSize: 128,
          zeroStage: 3,
        },
      },
    },
  },
}
/**
 * Build a fresh WorkbenchConfig from an example preset + phase.
 *
 * Every nested object is copied so later in-place edits to the returned
 * config cannot mutate the shared EXAMPLE_PRESETS data. Note the cluster's
 * nested gpuType object must be copied explicitly: a bare spread of
 * phase.cluster would alias the preset's GPUSpec object into every config
 * derived from the same phase.
 */
const createWorkbenchConfig = (
  examplePresetId: ExamplePresetId,
  phaseId: ExamplePhaseId,
): WorkbenchConfig => {
  const preset = EXAMPLE_PRESETS[examplePresetId]
  const phase = preset.phases[phaseId]
  return {
    examplePresetId,
    phaseId,
    customized: false,
    // model() is a factory, so this instance is already independent.
    model: preset.model(),
    training: { ...phase.training },
    // Copy one level deeper than a plain spread to avoid sharing gpuType.
    cluster: { ...phase.cluster, gpuType: { ...phase.cluster.gpuType } },
    parallelism: { ...phase.parallelism },
  }
}
// Canonical WorkbenchConfig for each scenario id. 'default' mirrors
// 'olmo-pretraining'. 'infeasible-memory' is hand-built: Llama 3.1 405B on
// only 8 H100 nodes (64 GPUs) with no optimizer sharding (zeroStage 0,
// distributedOptimizer false) — presumably meant to demonstrate a
// configuration that cannot fit in memory.
const SCENARIOS: Record<WorkbenchScenarioId, WorkbenchConfig> = {
  default: createWorkbenchConfig('olmo3-32b', 'pretraining'),
  'olmo-pretraining': createWorkbenchConfig('olmo3-32b', 'pretraining'),
  'olmo-long-context': createWorkbenchConfig('olmo3-32b', 'long-context'),
  'llama-pretraining': createWorkbenchConfig('llama31-405b', 'pretraining'),
  'llama-long-context': createWorkbenchConfig('llama31-405b', 'long-context'),
  'trinity-pretraining': createWorkbenchConfig('trinity-large-400b', 'pretraining'),
  'trinity-long-context': createWorkbenchConfig('trinity-large-400b', 'long-context'),
  'infeasible-memory': {
    examplePresetId: 'llama31-405b',
    phaseId: 'pretraining',
    customized: false,
    model: llama31_405B(),
    training: {
      microBatchSize: 1,
      seqLength: 8192,
      gradAccumSteps: 1,
      precision: 'bf16',
      activationCheckpointing: true,
      optimizer: 'adamw',
    },
    cluster: h100Cluster(8, 4),
    parallelism: {
      tp: 8,
      pp: 1,
      cp: 1,
      ep: 1,
      distributedOptimizer: false,
      fsdpShardGroupSize: 0,
      zeroStage: 0,
    },
  },
}
// Copy a ModelConfig, including its optional nested attentionProfile / moe
// objects (one level deep), so the clone can be edited independently.
const cloneModel = (model: ModelConfig): ModelConfig => {
  const attentionProfile = model.attentionProfile ? { ...model.attentionProfile } : undefined
  const moe = model.moe ? { ...model.moe } : undefined
  return { ...model, attentionProfile, moe }
}
// Shallow copy — sufficient while TrainingConfig holds only primitive fields
// (as all presets in this file do).
const cloneTraining = (training: TrainingConfig): TrainingConfig => Object.assign({}, training)
// Clone a ClusterConfig including its nested gpuType object, so edits to the
// clone's GPU spec cannot leak back into the source config. (A bare spread
// would alias gpuType between the original and the "clone".)
const cloneCluster = (cluster: ClusterConfig): ClusterConfig => ({
  ...cluster,
  gpuType: { ...cluster.gpuType },
})
// Shallow copy — sufficient while ParallelismConfig holds only primitive fields
// (as all presets in this file do).
const cloneParallelism = (parallelism: ParallelismConfig): ParallelismConfig =>
  Object.assign({}, parallelism)
/**
 * Produce an independent copy of a WorkbenchConfig so edits to one config do
 * not leak into another. Nested objects are copied via the clone* helpers.
 */
export const cloneWorkbenchConfig = (config: WorkbenchConfig): WorkbenchConfig => {
  const { examplePresetId, phaseId, customized } = config
  return {
    examplePresetId,
    phaseId,
    customized,
    model: cloneModel(config.model),
    training: cloneTraining(config.training),
    cluster: cloneCluster(config.cluster),
    parallelism: cloneParallelism(config.parallelism),
  }
}
/** Resolve a scenario id to an independent copy of its canonical config. */
export function getScenarioWorkbenchConfig(scenario: WorkbenchScenarioId) {
  const canonical = SCENARIOS[scenario]
  return cloneWorkbenchConfig(canonical)
}
/**
 * Example dropdown options. Llama 3.1 405B is excluded from the picker even
 * though its scenarios exist (reason not visible in this file).
 */
export function getExamplePresetOptions() {
  const options: { id: ExamplePresetId; label: string }[] = []
  for (const [id, preset] of Object.entries(EXAMPLE_PRESETS)) {
    if (id === 'llama31-405b') {
      continue
    }
    options.push({ id: id as ExamplePresetId, label: preset.label })
  }
  return options
}
/** Phase dropdown options for the given example preset, with display labels. */
export function getPhaseOptions(examplePresetId: ExamplePresetId) {
  const labels: Record<ExamplePhaseId, string> = {
    pretraining: 'Pretraining',
    'long-context': 'Long-context',
  }
  const { phases } = EXAMPLE_PRESETS[examplePresetId]
  return (Object.keys(phases) as ExamplePhaseId[]).map((phaseId) => ({
    id: phaseId,
    label: labels[phaseId],
  }))
}
/** Display label for an example preset. */
export function getExampleLabel(examplePresetId: ExamplePresetId) {
  const { label } = EXAMPLE_PRESETS[examplePresetId]
  return label
}
/** Dropdown options for every GPU preset, in declaration order. */
export function getGpuPresetOptions() {
  const ids = Object.keys(GPU_PRESETS) as GpuPresetId[]
  return ids.map((id) => ({ id, label: GPU_PRESETS[id].label }))
}
/**
 * Reverse-match a GPUSpec to one of the known presets; returns 'custom' when
 * no preset matches on name/HBM/TFLOPs/bandwidth.
 */
export function getGpuPresetId(gpuType: GPUSpec): GpuPresetId | 'custom' {
  const entry = Object.entries(GPU_PRESETS).find(([, preset]) =>
    gpuPresetMatches(gpuType, preset.spec()),
  )
  return entry === undefined ? 'custom' : (entry[0] as GpuPresetId)
}
/**
 * Replace the cluster's GPU with a preset spec and mark the config as
 * customized. All other cluster fields are preserved.
 */
export function applyGpuPreset(config: WorkbenchConfig, gpuPresetId: GpuPresetId): WorkbenchConfig {
  const gpuType = GPU_PRESETS[gpuPresetId].spec()
  return {
    ...config,
    customized: true,
    cluster: { ...config.cluster, gpuType },
  }
}
export function applyExamplePreset(
_config: WorkbenchConfig,
examplePresetId: ExamplePresetId,
): WorkbenchConfig {
return createWorkbenchConfig(examplePresetId, 'pretraining')
}
/**
 * Switch the current example to another phase, resetting any customizations
 * back to that phase's preset values.
 */
export function applyExamplePhase(
  config: WorkbenchConfig,
  phaseId: ExamplePhaseId,
): WorkbenchConfig {
  const { examplePresetId } = config
  return createWorkbenchConfig(examplePresetId, phaseId)
}
/**
 * All positive divisors of `total`, plus `currentValue` even when it is not a
 * divisor (presumably so the current UI selection stays listed), sorted
 * ascending with duplicates removed.
 */
export function getFactorOptions(total: number, currentValue: number) {
  const values = new Set<number>()
  values.add(currentValue)
  let divisor = 1
  while (divisor <= total) {
    if (total % divisor === 0) {
      values.add(divisor)
    }
    divisor += 1
  }
  return [...values].sort((a, b) => a - b)
}