|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass |
|
|
from typing import Dict, Optional |
|
|
|
|
|
import attrs |
|
|
import torch |
|
|
from hydra.core.config_store import ConfigStore |
|
|
|
|
|
from cosmos_predict1.diffusion.conditioner import VideoExtendCondition, VideoExtendConditioner |
|
|
from cosmos_predict1.diffusion.config.base.conditioner import ( |
|
|
FPSConfig, |
|
|
ImageSizeConfig, |
|
|
NumFramesConfig, |
|
|
PaddingMaskConfig, |
|
|
ReMapkey, |
|
|
TextConfig, |
|
|
VideoCondBoolConfig, |
|
|
) |
|
|
from cosmos_predict1.diffusion.training.config.video2world_action.experiment import register_experiments |
|
|
from cosmos_predict1.utils.lazy_config import LazyCall as L |
|
|
from cosmos_predict1.utils.lazy_config import LazyDict |
|
|
|
|
|
|
|
|
@dataclass
class ActionConditionalVideoExtendCondition(VideoExtendCondition):
    """``VideoExtendCondition`` with an additional optional ``action`` field.

    Instances are produced by ``ActionConditionalVideoExtendConditioner.forward``,
    which fills ``action`` from ``batch["action"]``.
    """

    # Action conditioning tensor; ``None`` unless supplied by the conditioner.
    # NOTE(review): shape/dtype are not constrained here — confirm against the
    # action dataset / model that consumes this field.
    action: Optional[torch.Tensor] = None
|
|
|
|
|
|
|
|
class ActionConditionalVideoExtendConditioner(VideoExtendConditioner):
    """Video-extend conditioner that also forwards the batch's ``action`` entry.

    ``forward`` collects the embedder outputs through the inherited ``_forward``
    helper, attaches ``batch["action"]`` under the ``"action"`` key, and packs
    the result into an :class:`ActionConditionalVideoExtendCondition`.
    """

    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> ActionConditionalVideoExtendCondition:
        """Build the action-conditioned video-extend condition for ``batch``.

        Args:
            batch: Input batch. It must contain an ``"action"`` entry.
            override_dropout_rate: Optional per-embedder dropout overrides,
                passed through unchanged to ``_forward``.

        Returns:
            An ``ActionConditionalVideoExtendCondition`` built from the
            ``_forward`` output dict with ``action`` set to ``batch["action"]``.

        Raises:
            KeyError: If ``batch`` has no ``"action"`` entry.
        """
        # Use an explicit check rather than ``assert``, because asserts are
        # removed under ``python -O``. Doing it before ``_forward`` also makes
        # this message win over any plain lookup failure inside the embedders.
        # (Presumably the ``action`` embedder reads batch["action"] — TODO confirm.)
        if "action" not in batch:
            raise KeyError("ActionConditionalVideoExtendConditioner requires 'action' in batch")

        # NOTE(review): this calls the base ``_forward`` directly instead of
        # ``VideoExtendConditioner.forward``. Confirm that no extra keys added
        # by the parent's ``forward`` are needed downstream.
        output = super()._forward(batch, override_dropout_rate)

        # The raw batch value is written last, so it replaces any "action" entry
        # that ``_forward`` may already have produced.
        # NOTE(review): if that embedder applies ``ActionConfig.dropout_rate``,
        # this overwrite bypasses the dropout — confirm this is intended.
        output["action"] = batch["action"]
        return ActionConditionalVideoExtendCondition(**output)
|
|
|
|
|
|
|
|
@attrs.define(slots=False)
class ActionConfig:
    """Conditioner entry that remaps the batch's ``action`` input to the ``action`` output key.

    Attributes:
        obj: Lazily-constructed ``ReMapkey`` with ``output_key="action"`` and
            ``dtype=None``.
        dropout_rate: Dropout rate for this entry; ``0.0`` by default.
        input_key: Batch key the entry reads from; ``"action"`` by default.
    """

    # NOTE(review): this default is evaluated once when the class is defined,
    # so every instance that does not pass ``obj`` shares the same object.
    # Consider ``attrs.Factory`` if instances are ever mutated in place.
    obj: LazyDict = attrs.field(default=L(ReMapkey)(output_key="action", dtype=None))
    dropout_rate: float = attrs.field(default=0.0)
    input_key: str = attrs.field(default="action")
|
|
|
|
|
|
|
|
# Lazy config for ActionConditionalVideoExtendConditioner. It uses the base
# text / fps / num_frames / image_size / padding_mask / video_cond_bool entries
# from the shared conditioner config, plus an ``action`` entry from ActionConfig.
# NOTE(review): keep the keyword order as-is. Presumably it also sets the order
# in which the embedders are built and run — confirm before reordering.
ActionConditionalVideoExtendConditionerConfig: LazyDict = L(ActionConditionalVideoExtendConditioner)(
    text=TextConfig(),
    fps=FPSConfig(),
    num_frames=NumFramesConfig(),
    image_size=ImageSizeConfig(),
    padding_mask=PaddingMaskConfig(),
    video_cond_bool=VideoCondBoolConfig(),
    # Remaps batch["action"] into the conditioner output.
    action=ActionConfig(),
)
|
|
|
|
|
|
|
|
def register_configs():
    """Register the video2world-action configs with Hydra's global ``ConfigStore``.

    This first lets ``register_experiments`` add its entries. It then stores
    ``ActionConditionalVideoExtendConditionerConfig`` as
    ``conditioner/action_conditional_video_cond``, placed under the
    ``model.conditioner`` package.
    """
    store = ConfigStore.instance()

    # Experiment presets are registered before the conditioner node,
    # matching the original order.
    register_experiments(store)

    # Exposes the action conditioner as a selectable option in the
    # "conditioner" group.
    store.store(
        group="conditioner",
        package="model.conditioner",
        name="action_conditional_video_cond",
        node=ActionConditionalVideoExtendConditionerConfig,
    )
|
|
|