import warnings
from dataclasses import dataclass, field
from typing import List, Literal, Optional

import numpy as np
import tyro

from starforce.data.dataset import LeRobotSingleDataset
from starforce.data.embodiment_tags import EMBODIMENT_TAG_MAPPING
from starforce.experiment.data_config import DATA_CONFIG_MAP
from starforce.model.policy import BasePolicy, Gr00tPolicy
from starforce.utils.eval import calc_mse_for_single_trajectory
|
|
# Silence FutureWarning output for the whole process so it does not drown out
# the evaluation prints below.
warnings.simplefilter(action="ignore", category=FutureWarning)
|
|
| """ |
| Example command: |
| |
| NOTE: pass --model-path to load the model checkpoint directly in this script; |
| otherwise the script builds a RobotInferenceClient using --host and --port. |
| |
| python tests/eval_mse.py --plot --model-path outputs/gr00t-3b-sl-3/ |
| |
| """ |
|
|
|
|
@dataclass
class ArgsConfig:
    """Configuration for evaluating a policy.

    The fields are parsed from the command line by ``tyro.cli``; the string
    that follows each field is its attribute docstring.
    """

    host: str = "localhost"
    """Host to connect to."""

    port: int = 5555
    """Port to connect to."""

    plot: bool = True
    """Whether to plot the images."""

    modality_keys: List[str] = field(
        default_factory=lambda: ["left_arm", "left_gripper", "right_arm", "right_gripper"]
    )
    """Modality keys to evaluate."""

    data_config: Literal[tuple(DATA_CONFIG_MAP.keys())] = "bimanual_agilex"
    """Data config to use."""

    steps: int = 150
    """Number of steps to evaluate."""

    trajs: int = 1
    """Number of trajectories to evaluate."""

    # Optional so that the annotation agrees with the ``None`` default.
    action_horizon: Optional[int] = None
    """Action horizon to evaluate. If None, will use the data config's action horizon."""

    video_backend: Literal["decord", "torchvision_av"] = "torchvision_av"
    """Video backend to use for various codec options. h264: decord or av: torchvision_av"""

    dataset_path: str = "data/sl/0721pre_data_v2"
    """Path to the dataset."""

    embodiment_tag: Literal[tuple(EMBODIMENT_TAG_MAPPING.keys())] = "new_embodiment"
    """Embodiment tag to use."""

    # Optional so that the annotation agrees with the ``None`` default.
    model_path: Optional[str] = None
    """Path to the model checkpoint."""

    denoising_steps: int = 4
    """Number of denoising steps to use."""

    save_plot_path: str = "outputs/mse_vis.png"
    """Path to save the plot."""
|
|
|
|
def _print_entries(entries):
    """Print each key with its array shape (numpy arrays) or its raw value."""
    for key, value in entries.items():
        if isinstance(value, np.ndarray):
            print(key, value.shape)
        else:
            print(key, value)


def main(args: ArgsConfig):
    """Evaluate a policy on dataset trajectories and print the per-trajectory MSE.

    Steps:
      1. Look up the data config named by ``args.data_config``.
      2. Build the policy: a local ``Gr00tPolicy`` when ``args.model_path`` is
         set, otherwise a ``RobotInferenceClient`` for ``args.host`` /
         ``args.port``.
      3. Load the dataset with the policy's modality config and print a short
         summary of its first sample.
      4. Call ``calc_mse_for_single_trajectory`` for trajectory ids
         ``0 .. args.trajs - 1`` and print each result and their mean.

    Args:
        args: Parsed evaluation options. ``args.action_horizon`` is filled in
            place with ``len(data_config.action_indices)`` when it is ``None``.

    Raises:
        ImportError: If no model path was given and ``RobotInferenceClient``
            cannot be imported.
        SystemExit: Always raised after the evaluation completes, ending the
            process as the original script did.
    """
    data_config = DATA_CONFIG_MAP[args.data_config]

    # Fall back to the data config's action horizon when none was given.
    if args.action_horizon is None:
        args.action_horizon = len(data_config.action_indices)
        print(f"Using action_horizon={args.action_horizon} from data config '{args.data_config}'")

    if args.model_path is not None:
        # torch is only needed when the model is loaded in this process.
        import torch

        modality_config = data_config.modality_config()
        modality_transform = data_config.transform()

        policy: BasePolicy = Gr00tPolicy(
            model_path=args.model_path,
            modality_config=modality_config,
            modality_transform=modality_transform,
            embodiment_tag=args.embodiment_tag,
            denoising_steps=args.denoising_steps,
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
    else:
        # RobotInferenceClient is not imported at module level in this file,
        # so it is resolved here instead of failing with a NameError.
        # NOTE(review): the module path mirrors the layout of the other
        # starforce imports and is an assumption — confirm it against the
        # installed package.
        try:
            from starforce.eval.robot import RobotInferenceClient
        except ImportError as err:
            raise ImportError(
                "No --model-path was given, so a RobotInferenceClient is required, "
                "but it could not be imported from starforce.eval.robot. "
                "Pass --model-path or make the client importable."
            ) from err
        policy = RobotInferenceClient(host=args.host, port=args.port)

    # The dataset is loaded with the modality config reported by the policy.
    modality = policy.get_modality_config()
    print("Current modality config: \n", modality)

    dataset = LeRobotSingleDataset(
        dataset_path=args.dataset_path,
        modality_configs=modality,
        video_backend=args.video_backend,
        video_backend_kwargs=None,
        transforms=None,
        embodiment_tag=args.embodiment_tag,
    )

    print(len(dataset))

    # Summarize the first sample, then step 0 of trajectory 0.
    _print_entries(dataset[0])
    _print_entries(dataset.get_step_data(0, 0))

    print("Total trajectories:", len(dataset.trajectory_lengths))
    print("All trajectories:", dataset.trajectory_lengths)
    print("Running on all trajs with modality keys:", args.modality_keys)

    all_mse = []
    for traj_id in range(args.trajs):
        print("Running trajectory:", traj_id)
        mse = calc_mse_for_single_trajectory(
            policy,
            dataset,
            traj_id,
            modality_keys=args.modality_keys,
            steps=args.steps,
            action_horizon=args.action_horizon,
            plot=args.plot,
            save_plot_path=args.save_plot_path,
        )
        print("MSE:", mse)
        all_mse.append(mse)
    print("Average MSE across all trajs:", np.mean(all_mse))
    print("Done")
    # Same effect as the interactive ``exit()`` helper, without depending on
    # the ``site`` module having injected it into builtins.
    raise SystemExit
|
|
|
|
if __name__ == "__main__":
    # Build an ArgsConfig from the command line with tyro and run the evaluation.
    main(tyro.cli(ArgsConfig))
|
|