Spaces:
Sleeping
Sleeping
| import { | |
| a100_80gb, | |
| b300, | |
| gb200, | |
| h100_sxm, | |
| llama31_405B, | |
| olmo3_32B, | |
| trinityLarge400B, | |
| type ClusterConfig, | |
| type GPUSpec, | |
| type ModelConfig, | |
| type ParallelismConfig, | |
| type TrainingConfig, | |
| } from './trainingClusterModel' | |
/** Identifiers for the built-in example model presets. */
export type ExamplePresetId = 'olmo3-32b' | 'llama31-405b' | 'trinity-large-400b'
/** Training phases each example preset ships a configuration for. */
export type ExamplePhaseId = 'pretraining' | 'long-context'
/** Identifiers for the selectable GPU hardware presets. */
export type GpuPresetId = 'a100-80gb' | 'h100-sxm' | 'b300' | 'gb200'
/** Named workbench starting points: example/phase combos plus special cases. */
export type WorkbenchScenarioId =
  | 'default'
  | 'olmo-pretraining'
  | 'olmo-long-context'
  | 'llama-pretraining'
  | 'llama-long-context'
  | 'trinity-pretraining'
  | 'trinity-long-context'
  | 'infeasible-memory'
/** Full editable state of the training workbench. */
export type WorkbenchConfig = {
  // Example preset this config was derived from.
  examplePresetId: ExamplePresetId
  // Phase of that preset this config was derived from.
  phaseId: ExamplePhaseId
  // false on freshly created preset configs; set true by applyGpuPreset.
  customized: boolean
  model: ModelConfig
  training: TrainingConfig
  cluster: ClusterConfig
  parallelism: ParallelismConfig
}
/** Per-phase portion of an example preset (everything except the model). */
type ExamplePhaseConfig = {
  cluster: ClusterConfig
  training: TrainingConfig
  parallelism: ParallelismConfig
}
/** A named example model plus one phase config per training phase. */
type ExamplePreset = {
  label: string
  // Factory invoked per config creation (see createWorkbenchConfig).
  model: () => ModelConfig
  phases: Record<ExamplePhaseId, ExamplePhaseConfig>
}
| const GPU_PRESETS: Record<GpuPresetId, { label: string; spec: () => GPUSpec }> = { | |
| 'a100-80gb': { | |
| label: 'A100 80GB', | |
| spec: a100_80gb, | |
| }, | |
| 'h100-sxm': { | |
| label: 'H100 SXM', | |
| spec: h100_sxm, | |
| }, | |
| b300: { | |
| label: 'B300', | |
| spec: b300, | |
| }, | |
| gb200: { | |
| label: 'GB200', | |
| spec: gb200, | |
| }, | |
| } | |
| const gpuPresetMatches = (candidate: GPUSpec, preset: GPUSpec) => | |
| candidate.name === preset.name && | |
| candidate.hbmCapacityGB === preset.hbmCapacityGB && | |
| candidate.peakTFLOPsBF16 === preset.peakTFLOPsBF16 && | |
| candidate.memBandwidthTBs === preset.memBandwidthTBs | |
| const h100Cluster = (numNodes: number, nodesPerRack: number): ClusterConfig => ({ | |
| gpuType: h100_sxm(), | |
| gpusPerNode: 8, | |
| numNodes, | |
| intraNodeBandwidthGBs: 900, | |
| interNodeBandwidthGBs: 50, | |
| nodesPerRack, | |
| rackLabel: 'rack', | |
| nodeLabel: 'GPU host', | |
| podLabel: 'rack', | |
| }) | |
| const b300Cluster = (numNodes: number, nodesPerRack: number): ClusterConfig => ({ | |
| gpuType: b300(), | |
| gpusPerNode: 8, | |
| numNodes, | |
| intraNodeBandwidthGBs: 900, | |
| interNodeBandwidthGBs: 50, | |
| nodesPerRack, | |
| rackLabel: 'rack', | |
| nodeLabel: 'GPU host', | |
| podLabel: 'rack', | |
| }) | |
/**
 * Built-in example presets: for each showcased model, a model factory plus
 * per-phase cluster / training / parallelism settings. These objects are
 * module-level shared state; createWorkbenchConfig copies them before handing
 * them to callers.
 */
export const EXAMPLE_PRESETS: Record<ExamplePresetId, ExamplePreset> = {
  'olmo3-32b': {
    label: 'OLMo 3 32B',
    model: olmo3_32B,
    phases: {
      // 128 H100 nodes, 16 per rack; no TP/PP, sharded via FSDP/ZeRO-3.
      pretraining: {
        cluster: h100Cluster(128, 16),
        training: {
          microBatchSize: 1,
          seqLength: 8192,
          gradAccumSteps: 1,
          precision: 'bf16',
          activationCheckpointing: true,
          optimizer: 'adamw',
        },
        parallelism: {
          tp: 1,
          pp: 1,
          cp: 1,
          ep: 1,
          distributedOptimizer: true,
          fsdpShardGroupSize: 256,
          zeroStage: 3,
        },
      },
      // Smaller cluster, 64K sequence length, context parallelism of 8.
      'long-context': {
        cluster: h100Cluster(32, 8),
        training: {
          microBatchSize: 1,
          seqLength: 65536,
          gradAccumSteps: 1,
          precision: 'bf16',
          activationCheckpointing: true,
          optimizer: 'adamw',
        },
        parallelism: {
          tp: 1,
          pp: 1,
          cp: 8,
          ep: 1,
          distributedOptimizer: true,
          fsdpShardGroupSize: 256,
          zeroStage: 3,
        },
      },
    },
  },
  'llama31-405b': {
    label: 'Llama 3.1 405B',
    model: llama31_405B,
    phases: {
      // 2048 H100 nodes; TP=8 within a node, PP=16 across nodes, ZeRO-1.
      pretraining: {
        cluster: h100Cluster(2048, 16),
        training: {
          microBatchSize: 1,
          seqLength: 8192,
          gradAccumSteps: 16,
          precision: 'bf16',
          activationCheckpointing: true,
          optimizer: 'adamw',
        },
        parallelism: {
          tp: 8,
          pp: 16,
          cp: 1,
          ep: 1,
          distributedOptimizer: true,
          fsdpShardGroupSize: 0,
          zeroStage: 1,
        },
      },
      // Same cluster, 128K sequence length, adds context parallelism of 16.
      'long-context': {
        cluster: h100Cluster(2048, 16),
        training: {
          microBatchSize: 1,
          seqLength: 131072,
          gradAccumSteps: 1,
          precision: 'bf16',
          activationCheckpointing: true,
          optimizer: 'adamw',
        },
        parallelism: {
          tp: 8,
          pp: 16,
          cp: 16,
          ep: 1,
          distributedOptimizer: true,
          fsdpShardGroupSize: 0,
          zeroStage: 1,
        },
      },
    },
  },
  'trinity-large-400b': {
    label: 'Trinity Large 400B',
    model: trinityLarge400B,
    phases: {
      // 256 B300 nodes, 9 per rack; expert parallelism of 8 (MoE model),
      // FSDP/ZeRO-3 sharding, Muon optimizer.
      pretraining: {
        cluster: b300Cluster(256, 9),
        training: {
          microBatchSize: 1,
          seqLength: 8192,
          gradAccumSteps: 8,
          precision: 'bf16',
          activationCheckpointing: true,
          optimizer: 'muon',
        },
        parallelism: {
          tp: 1,
          pp: 1,
          cp: 1,
          ep: 8,
          distributedOptimizer: true,
          fsdpShardGroupSize: 128,
          zeroStage: 3,
        },
      },
      // 256K sequence length, adds context parallelism of 4.
      'long-context': {
        cluster: b300Cluster(256, 9),
        training: {
          microBatchSize: 1,
          seqLength: 262144,
          gradAccumSteps: 1,
          precision: 'bf16',
          activationCheckpointing: true,
          optimizer: 'muon',
        },
        parallelism: {
          tp: 1,
          pp: 1,
          cp: 4,
          ep: 8,
          distributedOptimizer: true,
          fsdpShardGroupSize: 128,
          zeroStage: 3,
        },
      },
    },
  },
}
| const createWorkbenchConfig = ( | |
| examplePresetId: ExamplePresetId, | |
| phaseId: ExamplePhaseId, | |
| ): WorkbenchConfig => { | |
| const preset = EXAMPLE_PRESETS[examplePresetId] | |
| const phase = preset.phases[phaseId] | |
| return { | |
| examplePresetId, | |
| phaseId, | |
| customized: false, | |
| model: preset.model(), | |
| training: { ...phase.training }, | |
| cluster: { ...phase.cluster }, | |
| parallelism: { ...phase.parallelism }, | |
| } | |
| } | |
/**
 * Prebuilt workbench configurations addressable by scenario id.
 * 'default' is the same preset/phase pair as 'olmo-pretraining'.
 * 'infeasible-memory' is hand-built: Llama 405B on only 8 H100 nodes with no
 * pipeline parallelism, no distributed optimizer, and zeroStage 0 —
 * presumably meant to demonstrate an out-of-memory setup (per its name).
 */
const SCENARIOS: Record<WorkbenchScenarioId, WorkbenchConfig> = {
  default: createWorkbenchConfig('olmo3-32b', 'pretraining'),
  'olmo-pretraining': createWorkbenchConfig('olmo3-32b', 'pretraining'),
  'olmo-long-context': createWorkbenchConfig('olmo3-32b', 'long-context'),
  'llama-pretraining': createWorkbenchConfig('llama31-405b', 'pretraining'),
  'llama-long-context': createWorkbenchConfig('llama31-405b', 'long-context'),
  'trinity-pretraining': createWorkbenchConfig('trinity-large-400b', 'pretraining'),
  'trinity-long-context': createWorkbenchConfig('trinity-large-400b', 'long-context'),
  'infeasible-memory': {
    examplePresetId: 'llama31-405b',
    phaseId: 'pretraining',
    customized: false,
    model: llama31_405B(),
    training: {
      microBatchSize: 1,
      seqLength: 8192,
      gradAccumSteps: 1,
      precision: 'bf16',
      activationCheckpointing: true,
      optimizer: 'adamw',
    },
    cluster: h100Cluster(8, 4),
    parallelism: {
      tp: 8,
      pp: 1,
      cp: 1,
      ep: 1,
      distributedOptimizer: false,
      fsdpShardGroupSize: 0,
      zeroStage: 0,
    },
  },
}
| const cloneModel = (model: ModelConfig): ModelConfig => ({ | |
| ...model, | |
| attentionProfile: model.attentionProfile ? { ...model.attentionProfile } : undefined, | |
| moe: model.moe ? { ...model.moe } : undefined, | |
| }) | |
| const cloneTraining = (training: TrainingConfig): TrainingConfig => ({ ...training }) | |
| const cloneCluster = (cluster: ClusterConfig): ClusterConfig => ({ ...cluster }) | |
| const cloneParallelism = (parallelism: ParallelismConfig): ParallelismConfig => ({ | |
| ...parallelism, | |
| }) | |
| export const cloneWorkbenchConfig = (config: WorkbenchConfig): WorkbenchConfig => ({ | |
| examplePresetId: config.examplePresetId, | |
| phaseId: config.phaseId, | |
| customized: config.customized, | |
| model: cloneModel(config.model), | |
| training: cloneTraining(config.training), | |
| cluster: cloneCluster(config.cluster), | |
| parallelism: cloneParallelism(config.parallelism), | |
| }) | |
| export function getScenarioWorkbenchConfig(scenario: WorkbenchScenarioId) { | |
| return cloneWorkbenchConfig(SCENARIOS[scenario]) | |
| } | |
| export function getExamplePresetOptions() { | |
| return Object.entries(EXAMPLE_PRESETS) | |
| .filter(([id]) => id !== 'llama31-405b') | |
| .map(([id, preset]) => ({ | |
| id: id as ExamplePresetId, | |
| label: preset.label, | |
| })) | |
| } | |
| export function getPhaseOptions(examplePresetId: ExamplePresetId) { | |
| const preset = EXAMPLE_PRESETS[examplePresetId] | |
| return Object.keys(preset.phases).map((phaseId) => ({ | |
| id: phaseId as ExamplePhaseId, | |
| label: phaseId === 'pretraining' ? 'Pretraining' : 'Long-context', | |
| })) | |
| } | |
| export function getExampleLabel(examplePresetId: ExamplePresetId) { | |
| return EXAMPLE_PRESETS[examplePresetId].label | |
| } | |
| export function getGpuPresetOptions() { | |
| return Object.entries(GPU_PRESETS).map(([id, preset]) => ({ | |
| id: id as GpuPresetId, | |
| label: preset.label, | |
| })) | |
| } | |
| export function getGpuPresetId(gpuType: GPUSpec): GpuPresetId | 'custom' { | |
| for (const [id, preset] of Object.entries(GPU_PRESETS)) { | |
| if (gpuPresetMatches(gpuType, preset.spec())) { | |
| return id as GpuPresetId | |
| } | |
| } | |
| return 'custom' | |
| } | |
| export function applyGpuPreset(config: WorkbenchConfig, gpuPresetId: GpuPresetId): WorkbenchConfig { | |
| return { | |
| ...config, | |
| customized: true, | |
| cluster: { | |
| ...config.cluster, | |
| gpuType: GPU_PRESETS[gpuPresetId].spec(), | |
| }, | |
| } | |
| } | |
| export function applyExamplePreset( | |
| _config: WorkbenchConfig, | |
| examplePresetId: ExamplePresetId, | |
| ): WorkbenchConfig { | |
| return createWorkbenchConfig(examplePresetId, 'pretraining') | |
| } | |
| export function applyExamplePhase( | |
| config: WorkbenchConfig, | |
| phaseId: ExamplePhaseId, | |
| ): WorkbenchConfig { | |
| return createWorkbenchConfig(config.examplePresetId, phaseId) | |
| } | |
| export function getFactorOptions(total: number, currentValue: number) { | |
| const factors = new Set<number>([currentValue]) | |
| for (let candidate = 1; candidate <= total; candidate += 1) { | |
| if (total % candidate === 0) { | |
| factors.add(candidate) | |
| } | |
| } | |
| return Array.from(factors).sort((left, right) => left - right) | |
| } | |