from cosmos_predict1.diffusion.training.utils.checkpointer import MultiRankCheckpointer
from cosmos_predict1.utils.fsdp_checkpointer import FSDPCheckpointer
from cosmos_predict1.utils.trainer import Trainer as BaseTrainer


class Trainer(BaseTrainer):
    """Trainer that swaps in a checkpointer matching the configured distributed parallelism strategy."""

    def __init__(self, config):
        super().__init__(config)
        # Replace the base trainer's checkpointer with one that understands how
        # model/optimizer state is sharded under the chosen parallelism scheme.
        if config.trainer.distributed_parallelism == "ddp":
            self.checkpointer = MultiRankCheckpointer(config.checkpoint, config.job, callbacks=self.callbacks)
        elif config.trainer.distributed_parallelism == "fsdp":
            self.checkpointer = FSDPCheckpointer(config.checkpoint, config.job, callbacks=self.callbacks)
        else:
            raise ValueError(f"Unsupported distributed parallelism: {config.trainer.distributed_parallelism}")
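

# Usage sketch (illustrative only; the surrounding training entry point and the
# way `config` is built are assumptions, not part of this module): the
# checkpointer is chosen solely from `config.trainer.distributed_parallelism`,
# so "fsdp" yields an FSDPCheckpointer, "ddp" yields a MultiRankCheckpointer,
# and any other value raises ValueError at construction time.
#
#   config.trainer.distributed_parallelism = "fsdp"
#   trainer = Trainer(config)
#   assert isinstance(trainer.checkpointer, FSDPCheckpointer)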