import {
  a100_80gb,
  b300,
  gb200,
  h100_sxm,
  llama31_405B,
  olmo3_32B,
  trinityLarge400B,
  type ClusterConfig,
  type GPUSpec,
  type ModelConfig,
  type ParallelismConfig,
  type TrainingConfig,
} from './trainingClusterModel'

export type ExamplePresetId = 'olmo3-32b' | 'llama31-405b' | 'trinity-large-400b'

export type ExamplePhaseId = 'pretraining' | 'long-context'

export type GpuPresetId = 'a100-80gb' | 'h100-sxm' | 'b300' | 'gb200'

export type WorkbenchScenarioId =
  | 'default'
  | 'olmo-pretraining'
  | 'olmo-long-context'
  | 'llama-pretraining'
  | 'llama-long-context'
  | 'trinity-pretraining'
  | 'trinity-long-context'
  | 'infeasible-memory'

/** Full state edited by the workbench: which example/phase it came from plus every config section. */
export type WorkbenchConfig = {
  examplePresetId: ExamplePresetId
  phaseId: ExamplePhaseId
  /** `false` for configs built straight from a preset; `applyGpuPreset` sets it to `true`. */
  customized: boolean
  model: ModelConfig
  training: TrainingConfig
  cluster: ClusterConfig
  parallelism: ParallelismConfig
}

/** Cluster, training and parallelism settings for one phase of an example preset. */
type ExamplePhaseConfig = {
  cluster: ClusterConfig
  training: TrainingConfig
  parallelism: ParallelismConfig
}

/** A named example model with per-phase settings. `model` is a factory invoked per config. */
type ExamplePreset = {
  label: string
  model: () => ModelConfig
  phases: Record<ExamplePhaseId, ExamplePhaseConfig>
}

/** Selectable GPU presets. `spec` is a factory so each use gets its own `GPUSpec` object. */
const GPU_PRESETS: Record<GpuPresetId, { label: string; spec: () => GPUSpec }> = {
  'a100-80gb': {
    label: 'A100 80GB',
    spec: a100_80gb,
  },
  'h100-sxm': {
    label: 'H100 SXM',
    spec: h100_sxm,
  },
  b300: {
    label: 'B300',
    spec: b300,
  },
  gb200: {
    label: 'GB200',
    spec: gb200,
  },
}

/**
 * Structural equality used to recognise a preset GPU: compares name, HBM capacity,
 * peak BF16 TFLOPs and memory bandwidth (other `GPUSpec` fields are ignored).
 */
const gpuPresetMatches = (candidate: GPUSpec, preset: GPUSpec) =>
  candidate.name === preset.name &&
  candidate.hbmCapacityGB === preset.hbmCapacityGB &&
  candidate.peakTFLOPsBF16 === preset.peakTFLOPsBF16 &&
  candidate.memBandwidthTBs === preset.memBandwidthTBs

/**
 * Shared builder for the example clusters: 8 GPUs per node, 900 GB/s intra-node and
 * 50 GB/s inter-node bandwidth (units per the field names), and 'rack' / 'GPU host' labels.
 */
const eightGpuNodeCluster = (
  gpuType: GPUSpec,
  numNodes: number,
  nodesPerRack: number,
): ClusterConfig => ({
  gpuType,
  gpusPerNode: 8,
  numNodes,
  intraNodeBandwidthGBs: 900,
  interNodeBandwidthGBs: 50,
  nodesPerRack,
  rackLabel: 'rack',
  nodeLabel: 'GPU host',
  podLabel: 'rack',
})

/** H100 SXM cluster with `numNodes` nodes grouped `nodesPerRack` per rack. */
const h100Cluster = (numNodes: number, nodesPerRack: number): ClusterConfig =>
  eightGpuNodeCluster(h100_sxm(), numNodes, nodesPerRack)

/**
 * B300 cluster with `numNodes` nodes grouped `nodesPerRack` per rack.
 * NOTE(review): this reuses the same intra/inter-node bandwidth figures as the H100
 * cluster; confirm those numbers are intended for B300 hardware.
 */
const b300Cluster = (numNodes: number, nodesPerRack: number): ClusterConfig =>
  eightGpuNodeCluster(b300(), numNodes, nodesPerRack)

/** Built-in example models and their pretraining / long-context phase settings. */
export const EXAMPLE_PRESETS: Record<ExamplePresetId, ExamplePreset> = {
  'olmo3-32b': {
    label: 'OLMo 3 32B',
    model: olmo3_32B,
    phases: {
      pretraining: {
        cluster: h100Cluster(128, 16),
        training: {
          microBatchSize: 1,
          seqLength: 8192,
          gradAccumSteps: 1,
          precision: 'bf16',
          activationCheckpointing: true,
          optimizer: 'adamw',
        },
        parallelism: {
          tp: 1,
          pp: 1,
          cp: 1,
          ep: 1,
          distributedOptimizer: true,
          fsdpShardGroupSize: 256,
          zeroStage: 3,
        },
      },
      'long-context': {
        cluster: h100Cluster(32, 8),
        training: {
          microBatchSize: 1,
          seqLength: 65536,
          gradAccumSteps: 1,
          precision: 'bf16',
          activationCheckpointing: true,
          optimizer: 'adamw',
        },
        parallelism: {
          tp: 1,
          pp: 1,
          cp: 8,
          ep: 1,
          distributedOptimizer: true,
          fsdpShardGroupSize: 256,
          zeroStage: 3,
        },
      },
    },
  },
  'llama31-405b': {
    label: 'Llama 3.1 405B',
    model: llama31_405B,
    phases: {
      pretraining: {
        cluster: h100Cluster(2048, 16),
        training: {
          microBatchSize: 1,
          seqLength: 8192,
          gradAccumSteps: 16,
          precision: 'bf16',
          activationCheckpointing: true,
          optimizer: 'adamw',
        },
        parallelism: {
          tp: 8,
          pp: 16,
          cp: 1,
          ep: 1,
          distributedOptimizer: true,
          fsdpShardGroupSize: 0,
          zeroStage: 1,
        },
      },
      'long-context': {
        cluster: h100Cluster(2048, 16),
        training: {
          microBatchSize: 1,
          seqLength: 131072,
          gradAccumSteps: 1,
          precision: 'bf16',
          activationCheckpointing: true,
          optimizer: 'adamw',
        },
        parallelism: {
          tp: 8,
          pp: 16,
          cp: 16,
          ep: 1,
          distributedOptimizer: true,
          fsdpShardGroupSize: 0,
          zeroStage: 1,
        },
      },
    },
  },
  'trinity-large-400b': {
    label: 'Trinity Large 400B',
    model: trinityLarge400B,
    phases: {
      pretraining: {
        cluster: b300Cluster(256, 9),
        training: {
          microBatchSize: 1,
          seqLength: 8192,
          gradAccumSteps: 8,
          precision: 'bf16',
          activationCheckpointing: true,
          optimizer: 'muon',
        },
        parallelism: {
          tp: 1,
          pp: 1,
          cp: 1,
          ep: 8,
          distributedOptimizer: true,
          fsdpShardGroupSize: 128,
          zeroStage: 3,
        },
      },
      'long-context': {
        cluster: b300Cluster(256, 9),
        training: {
          microBatchSize: 1,
          seqLength: 262144,
          gradAccumSteps: 1,
          precision: 'bf16',
          activationCheckpointing: true,
          optimizer: 'muon',
        },
        parallelism: {
          tp: 1,
          pp: 1,
          cp: 4,
          ep: 8,
          distributedOptimizer: true,
          fsdpShardGroupSize: 128,
          zeroStage: 3,
        },
      },
    },
  },
}

// Clone helpers. They are defined before SCENARIOS because createWorkbenchConfig (called
// while SCENARIOS is initialised) uses them; `const` arrows are not hoisted.

/** Copies the model plus its optional nested `attentionProfile` / `moe` objects. */
const cloneModel = (model: ModelConfig): ModelConfig => ({
  ...model,
  attentionProfile: model.attentionProfile ? { ...model.attentionProfile } : undefined,
  moe: model.moe ? { ...model.moe } : undefined,
})

const cloneTraining = (training: TrainingConfig): TrainingConfig => ({ ...training })

/**
 * Copies the cluster and its nested `gpuType`, so editing one config's GPU spec cannot
 * leak into presets or other configs. Only one level of `GPUSpec` is copied; if GPUSpec
 * ever gains nested objects they would still be shared — TODO confirm it stays flat.
 */
const cloneCluster = (cluster: ClusterConfig): ClusterConfig => ({
  ...cluster,
  gpuType: { ...cluster.gpuType },
})

const cloneParallelism = (parallelism: ParallelismConfig): ParallelismConfig => ({
  ...parallelism,
})

/**
 * Builds an uncustomised config for one example/phase. Phase sections are cloned so the
 * result shares no mutable state with EXAMPLE_PRESETS; the model comes from the preset's
 * factory (presumably a fresh object per call — confirm in trainingClusterModel).
 */
const createWorkbenchConfig = (
  examplePresetId: ExamplePresetId,
  phaseId: ExamplePhaseId,
): WorkbenchConfig => {
  const preset = EXAMPLE_PRESETS[examplePresetId]
  const phase = preset.phases[phaseId]

  return {
    examplePresetId,
    phaseId,
    customized: false,
    model: preset.model(),
    training: cloneTraining(phase.training),
    cluster: cloneCluster(phase.cluster),
    parallelism: cloneParallelism(phase.parallelism),
  }
}

/** Named starting points. 'infeasible-memory' is a hand-built Llama 405B config on only 8 H100 nodes. */
const SCENARIOS: Record<WorkbenchScenarioId, WorkbenchConfig> = {
  default: createWorkbenchConfig('olmo3-32b', 'pretraining'),
  'olmo-pretraining': createWorkbenchConfig('olmo3-32b', 'pretraining'),
  'olmo-long-context': createWorkbenchConfig('olmo3-32b', 'long-context'),
  'llama-pretraining': createWorkbenchConfig('llama31-405b', 'pretraining'),
  'llama-long-context': createWorkbenchConfig('llama31-405b', 'long-context'),
  'trinity-pretraining': createWorkbenchConfig('trinity-large-400b', 'pretraining'),
  'trinity-long-context': createWorkbenchConfig('trinity-large-400b', 'long-context'),
  'infeasible-memory': {
    examplePresetId: 'llama31-405b',
    phaseId: 'pretraining',
    customized: false,
    model: llama31_405B(),
    training: {
      microBatchSize: 1,
      seqLength: 8192,
      gradAccumSteps: 1,
      precision: 'bf16',
      activationCheckpointing: true,
      optimizer: 'adamw',
    },
    cluster: h100Cluster(8, 4),
    parallelism: {
      tp: 8,
      pp: 1,
      cp: 1,
      ep: 1,
      distributedOptimizer: false,
      fsdpShardGroupSize: 0,
      zeroStage: 0,
    },
  },
}

/** Returns a copy of `config` whose sections (including nested model/GPU objects) are independent. */
export const cloneWorkbenchConfig = (config: WorkbenchConfig): WorkbenchConfig => ({
  examplePresetId: config.examplePresetId,
  phaseId: config.phaseId,
  customized: config.customized,
  model: cloneModel(config.model),
  training: cloneTraining(config.training),
  cluster: cloneCluster(config.cluster),
  parallelism: cloneParallelism(config.parallelism),
})

/** Returns a fresh, caller-owned copy of the named scenario. */
export function getScenarioWorkbenchConfig(scenario: WorkbenchScenarioId) {
  return cloneWorkbenchConfig(SCENARIOS[scenario])
}

/**
 * Example presets offered in the picker. 'llama31-405b' is excluded here; it is still
 * reachable through the 'llama-*' and 'infeasible-memory' scenarios.
 */
export function getExamplePresetOptions() {
  return Object.entries(EXAMPLE_PRESETS)
    .filter(([id]) => id !== 'llama31-405b')
    .map(([id, preset]) => ({
      id: id as ExamplePresetId,
      label: preset.label,
    }))
}

/** Display labels for each training phase. */
const PHASE_LABELS: Record<ExamplePhaseId, string> = {
  pretraining: 'Pretraining',
  'long-context': 'Long-context',
}

/** Phase options (id + label) defined for the given example preset, in declaration order. */
export function getPhaseOptions(examplePresetId: ExamplePresetId) {
  const preset = EXAMPLE_PRESETS[examplePresetId]
  const phaseIds = Object.keys(preset.phases) as ExamplePhaseId[]
  return phaseIds.map((phaseId) => ({
    id: phaseId,
    label: PHASE_LABELS[phaseId],
  }))
}

/** Human-readable label of an example preset. */
export function getExampleLabel(examplePresetId: ExamplePresetId) {
  return EXAMPLE_PRESETS[examplePresetId].label
}

/** All GPU presets as picker options. */
export function getGpuPresetOptions() {
  return Object.entries(GPU_PRESETS).map(([id, preset]) => ({
    id: id as GpuPresetId,
    label: preset.label,
  }))
}

/** Identifies which preset `gpuType` matches (see gpuPresetMatches), or 'custom' if none. */
export function getGpuPresetId(gpuType: GPUSpec): GpuPresetId | 'custom' {
  for (const [id, preset] of Object.entries(GPU_PRESETS)) {
    if (gpuPresetMatches(gpuType, preset.spec())) {
      return id as GpuPresetId
    }
  }
  return 'custom'
}

/** Returns a new config using the preset GPU and marked as customized; `config` is not mutated. */
export function applyGpuPreset(config: WorkbenchConfig, gpuPresetId: GpuPresetId): WorkbenchConfig {
  return {
    ...config,
    customized: true,
    cluster: {
      ...config.cluster,
      gpuType: GPU_PRESETS[gpuPresetId].spec(),
    },
  }
}

/** Switches to another example, always starting at its pretraining phase; the current config is discarded. */
export function applyExamplePreset(
  _config: WorkbenchConfig,
  examplePresetId: ExamplePresetId,
): WorkbenchConfig {
  return createWorkbenchConfig(examplePresetId, 'pretraining')
}

/** Switches phase within the current example; any customizations are replaced by the preset values. */
export function applyExamplePhase(
  config: WorkbenchConfig,
  phaseId: ExamplePhaseId,
): WorkbenchConfig {
  return createWorkbenchConfig(config.examplePresetId, phaseId)
}

/**
 * Ascending positive divisors of `total`, plus `currentValue` (always included so the
 * current selection stays selectable even when it does not divide `total`).
 */
export function getFactorOptions(total: number, currentValue: number) {
  const factors = new Set<number>([currentValue])
  for (let candidate = 1; candidate <= total; candidate += 1) {
    if (total % candidate === 0) {
      factors.add(candidate)
    }
  }
  return Array.from(factors).sort((left, right) => left - right)
}