import dataclasses
from typing import Protocol, runtime_checkable

import jax.numpy as jnp
import optax

import openpi.shared.array_typing as at

@runtime_checkable
class LRScheduleConfig(Protocol):
    def create(self) -> optax.Schedule: ...

@dataclasses.dataclass(frozen=True)
class CosineDecaySchedule(LRScheduleConfig):
    """Cosine decay schedule with warmup."""

    warmup_steps: int = 1_000
    peak_lr: float = 2.5e-5
    decay_steps: int = 30_000
    decay_lr: float = 2.5e-6

    def create(self) -> optax.Schedule:
        return optax.warmup_cosine_decay_schedule(
            init_value=self.peak_lr / (self.warmup_steps + 1),
            peak_value=self.peak_lr,
            warmup_steps=self.warmup_steps,
            decay_steps=self.decay_steps,
            end_value=self.decay_lr,
        )

@dataclasses.dataclass(frozen=True)
class RsqrtDecaySchedule(LRScheduleConfig):
    """Inverse square root decay schedule with warmup."""

    warmup_steps: int = 1_000
    peak_lr: float = 5e-5
    timescale: float = 10_000

    def create(self) -> optax.Schedule:
        # Linear warmup to peak_lr, then an inverse square-root decay that starts
        # from peak_lr and falls off on the order of the `timescale` horizon.
        return optax.join_schedules(
            [
                optax.linear_schedule(
                    init_value=self.peak_lr / (self.warmup_steps + 1),
                    end_value=self.peak_lr,
                    transition_steps=self.warmup_steps,
                ),
                lambda step: self.peak_lr / jnp.sqrt((self.timescale + step) / self.timescale),
            ],
            [self.warmup_steps],
        )

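# Illustrative sketch, not part of the original module: a schedule config's
# `create()` returns a plain `optax.Schedule`, i.e. a callable from the step
# count to a learning rate. The probe steps below are assumptions chosen to
# match the default warmup_steps / decay_steps.
def _example_schedule_probe() -> None:
    schedule = CosineDecaySchedule().create()
    assert float(schedule(0)) < float(schedule(1_000))  # LR ramps up during warmup...
    assert float(schedule(30_000)) < float(schedule(1_000))  # ...then decays toward decay_lr.
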
@runtime_checkable
class OptimizerConfig(Protocol):
    def create(
        self,
        lr: optax.ScalarOrSchedule,
        weight_decay_mask: at.PyTree | None = None,
    ) -> optax.GradientTransformation: ...

@dataclasses.dataclass(frozen=True)
class AdamW(OptimizerConfig):
    """AdamW optimizer."""

    b1: float = 0.9
    b2: float = 0.95
    eps: float = 1e-8
    weight_decay: float = 1e-10
    clip_gradient_norm: float = 1.0

    def create(
        self,
        lr: optax.ScalarOrSchedule,
        weight_decay_mask: at.PyTree | None = None,
    ) -> optax.GradientTransformation:
        tx = optax.adamw(
            lr, b1=self.b1, b2=self.b2, eps=self.eps, weight_decay=self.weight_decay, mask=weight_decay_mask
        )

        # Clip gradients by their global norm before applying the AdamW update.
        return optax.chain(optax.clip_by_global_norm(self.clip_gradient_norm), tx)

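# Illustrative sketch (an assumption, not part of the original openpi API): one
# common way to build a `weight_decay_mask` for `AdamW.create` is to apply decay
# only to arrays with two or more dimensions, skipping biases and norm scales.
def _example_weight_decay_mask(params: at.PyTree) -> at.PyTree:
    import jax  # Local import keeps the sketch self-contained; the module itself only needs jax.numpy.

    return jax.tree_util.tree_map(lambda p: jnp.ndim(p) >= 2, params)
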
@dataclasses.dataclass(frozen=True)
class SGD(OptimizerConfig):
    """SGD optimizer."""

    lr: float = 5e-5
    momentum: float = 0.9
    nesterov: bool = False

    def create(
        self,
        lr: optax.ScalarOrSchedule,
        weight_decay_mask: at.PyTree | None = None,
    ) -> optax.GradientTransformation:
        assert weight_decay_mask is None, "Weight decay is not supported for SGD"
        return optax.sgd(lr, momentum=self.momentum, nesterov=self.nesterov)

def create_optimizer(
    optimizer: OptimizerConfig, lr_schedule: LRScheduleConfig, weight_decay_mask: at.PyTree | None = None
) -> optax.GradientTransformation:
    lr = lr_schedule.create()
    return optimizer.create(lr, weight_decay_mask=weight_decay_mask)
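
# Illustrative usage sketch: the configs and parameter shapes below are made-up
# assumptions, showing how `create_optimizer` wires a schedule config and an
# optimizer config into a single `optax.GradientTransformation`.
def _example_usage() -> None:
    params = {"w": jnp.zeros((4, 4)), "b": jnp.zeros((4,))}
    grads = {"w": jnp.ones((4, 4)), "b": jnp.ones((4,))}
    tx = create_optimizer(AdamW(), CosineDecaySchedule())
    opt_state = tx.init(params)
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)  # New params after one step; a training loop repeats this.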