|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
from cosmos_predict1.diffusion.training.trainer import Trainer |
|
|
from cosmos_predict1.utils.callback import LowPrecisionCallback as BaseCallback |
|
|
from cosmos_predict1.utils.config import Config |
|
|
from cosmos_predict1.utils.model import Model |
|
|
|
|
|
|
|
|
class LowPrecisionCallback(BaseCallback):
    """Low-precision training callback that reads its dtype from the model.

    Config with non-primitive type makes it difficult to override the option,
    so instead of taking the precision as a constructor/config argument, the
    callback gets the precision from ``model.precision`` when training starts.
    """

    # Dtypes accepted as "low precision". ``torch.half`` is an alias of
    # ``torch.float16``, so it does not need to be listed separately.
    _LOW_PRECISION_DTYPES = (torch.bfloat16, torch.float16)

    def __init__(self, config: Config, trainer: Trainer, update_iter: int):
        """Store the callback's collaborators.

        Args:
            config: Project ``Config``; stored as ``self.config``.
            trainer: ``Trainer`` instance; stored as ``self.trainer``.
            update_iter: Stored as ``self.update_iter``. How it is consumed is
                defined outside this file (presumably by the base callback's
                hooks) -- NOTE(review): confirm.
        """
        self.config = config
        self.trainer = trainer
        self.update_iter = update_iter

    def on_train_start(self, model: Model, iteration: int = 0) -> None:
        """Validate the model's precision and record it on the callback.

        Args:
            model: Model being trained; must expose a ``precision`` attribute.
            iteration: Starting iteration; not used by this hook.

        Raises:
            AssertionError: If ``model.precision`` is not ``torch.bfloat16`` or
                ``torch.float16`` (``torch.half``).
        """
        precision = model.precision
        # Raise explicitly instead of using ``assert`` so the check is not
        # stripped under ``python -O``; AssertionError is kept so the exception
        # type seen by callers is unchanged.
        if precision not in self._LOW_PRECISION_DTYPES:
            raise AssertionError("LowPrecisionCallback must use a low precision dtype.")
        self.precision_type = precision
|
|
|