| from typing import Dict |
| import torch |
| import torch.nn as nn |
| from equi_diffpo.model.common.module_attr_mixin import ModuleAttrMixin |
| from equi_diffpo.model.common.normalizer import LinearNormalizer |
|
|
class BaseLowdimPolicy(ModuleAttrMixin):
    """Base interface for policies that act on low-dimensional observations.

    Concrete policies override ``predict_action`` and ``set_normalizer``;
    ``reset`` only needs overriding by policies that keep internal state.
    """

    # ---------------------------- inference ----------------------------

    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Predict a sequence of actions from a batch of observations.

        Must be implemented by subclasses; the base version always raises
        ``NotImplementedError``.

        Args:
            obs_dict: expected to hold
                ``obs`` -- tensor shaped (B, To, Do): batch size,
                number of observation steps, observation dimension.

        Returns:
            A dict expected to hold
                ``action`` -- tensor shaped (B, Ta, Da): batch size,
                number of action steps, action dimension.

        Timeline sketch (one column per time step) for To = 3, Ta = 4,
        T = 6::

            |o|o|o|
            | | |a|a|a|a|

        A further sketch from the original author (its To/Ta/T are not
        stated; presumably To = 2 -- TODO confirm)::

            |o|o|
            |a|a|a|a|a|
            | | | | |a|a|
        """
        raise NotImplementedError

    def reset(self):
        """Clear any internal policy state.

        The base implementation does nothing, so stateless policies can
        inherit it unchanged. Presumably invoked between rollouts -- verify
        against callers.
        """
        return None

    # ----------------------------- training -----------------------------

    def set_normalizer(self, normalizer: LinearNormalizer):
        """Install the normalizer the policy should use.

        Must be implemented by subclasses; the base version always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError