# Source: gr00t1.5_starforce/tests/eval_mse.py
# (commit 11235a4, uploaded by nnh-pbbb using the upload-large-folder tool)
import os
import warnings
from dataclasses import dataclass, field
from typing import List, Literal, Optional

import numpy as np
import tyro

from starforce.data.dataset import LeRobotSingleDataset
from starforce.data.embodiment_tags import EMBODIMENT_TAG_MAPPING
# from starforce.eval.robot import RobotInferenceClient
from starforce.experiment.data_config import DATA_CONFIG_MAP
from starforce.model.policy import BasePolicy, Gr00tPolicy
from starforce.utils.eval import calc_mse_for_single_trajectory
# Silence FutureWarning messages process-wide so the evaluation output stays readable.
warnings.simplefilter(action="ignore", category=FutureWarning)
"""
Example command:
NOTE: pass --model-path to load the model checkpoint in-process; otherwise the
script connects to the inference server at --host/--port (default localhost:5555)
via RobotInferenceClient.
python tests/eval_mse.py --plot --model-path outputs/gr00t-3b-sl-3/
"""
@dataclass
class ArgsConfig:
    """Configuration for evaluating a policy."""

    # NOTE: the bare string after each field is read by tyro as that flag's --help
    # text, so keep it directly below the field it describes.

    host: str = "localhost"
    """Host of the inference server; only used when --model-path is not given."""
    port: int = 5555
    """Port of the inference server; only used when --model-path is not given."""
    plot: bool = True
    """Whether to plot the results of each evaluated trajectory."""
    modality_keys: List[str] = field(
        default_factory=lambda: ["left_arm", "left_gripper", "right_arm", "right_gripper"]
    )
    """Modality keys to evaluate."""
    data_config: Literal[tuple(DATA_CONFIG_MAP.keys())] = "bimanual_agilex"
    """Data config to use."""
    steps: int = 150
    """Number of steps to evaluate."""
    trajs: int = 1
    """Number of trajectories to evaluate (trajectory ids 0 .. trajs-1)."""
    # Optional: None means "derive from the data config" (see main()).
    action_horizon: Optional[int] = None
    """Action horizon to evaluate. If None, the length of the data config's action_indices is used."""
    video_backend: Literal["decord", "torchvision_av"] = "torchvision_av"
    """Video backend to use for various codec options. h264: decord or av: torchvision_av"""
    dataset_path: str = "data/sl/0721pre_data_v2"
    """Path to the dataset."""
    embodiment_tag: Literal[tuple(EMBODIMENT_TAG_MAPPING.keys())] = "new_embodiment"
    """Embodiment tag to use."""
    # Optional: None selects the remote RobotInferenceClient at host:port.
    model_path: Optional[str] = None
    """Path to the model checkpoint. If None, a remote policy at host:port is used instead."""
    denoising_steps: int = 4
    """Number of denoising steps to use."""
    save_plot_path: str = "outputs/mse_vis.png"
    """Path to save the plot."""
def _print_sample(sample) -> None:
    """Print every entry of a sample dict, summarizing numpy arrays by their shape."""
    for key, value in sample.items():
        if isinstance(value, np.ndarray):
            print(key, value.shape)
        else:
            print(key, value)


def _plot_path_for_traj(save_plot_path: str, traj_id: int, num_trajs: int) -> str:
    """Return the plot path to use for trajectory ``traj_id``.

    With a single trajectory the configured path is returned unchanged. With several,
    ``_traj<id>`` is inserted before the extension (``mse_vis.png`` ->
    ``mse_vis_traj3.png``). Otherwise every trajectory would receive the same path and
    presumably overwrite the previous plot.
    """
    if num_trajs <= 1:
        return save_plot_path
    root, ext = os.path.splitext(save_plot_path)
    return f"{root}_traj{traj_id}{ext}"


def _make_policy(args: ArgsConfig, data_config) -> BasePolicy:
    """Build the policy to evaluate.

    Returns a local Gr00tPolicy when ``args.model_path`` is set; otherwise a
    RobotInferenceClient connected to ``args.host:args.port``.

    Raises:
        ImportError: no model path was given and RobotInferenceClient cannot be imported.
    """
    if args.model_path is not None:
        import torch  # only needed when running the checkpoint in-process

        return Gr00tPolicy(
            model_path=args.model_path,
            modality_config=data_config.modality_config(),
            modality_transform=data_config.transform(),
            embodiment_tag=args.embodiment_tag,
            denoising_steps=args.denoising_steps,
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
    # Imported lazily: the module-level import is commented out, so referencing the
    # name directly would raise NameError on the default (no --model-path) path.
    try:
        from starforce.eval.robot import RobotInferenceClient
    except ImportError as err:
        raise ImportError(
            "RobotInferenceClient is unavailable; pass --model-path to evaluate a "
            "local checkpoint instead of connecting to an inference server."
        ) from err
    return RobotInferenceClient(host=args.host, port=args.port)


def main(args: ArgsConfig):
    """Evaluate a policy's action-prediction MSE on trajectories of a LeRobot dataset.

    Builds the policy (see ``_make_policy``) and loads the dataset using the
    policy's modality config. It prints a few diagnostics about the dataset, then
    runs ``calc_mse_for_single_trajectory`` for trajectory ids ``0 .. args.trajs - 1``.
    Each trajectory's MSE is printed, followed by the mean across trajectories.

    Args:
        args: Parsed command-line configuration.
    """
    data_config = DATA_CONFIG_MAP[args.data_config]

    # Fall back to the data config's action horizon; keep args itself unmodified.
    action_horizon = args.action_horizon
    if action_horizon is None:
        action_horizon = len(data_config.action_indices)
        print(f"Using action_horizon={action_horizon} from data config '{args.data_config}'")

    policy: BasePolicy = _make_policy(args, data_config)

    # Get the supported modalities for the policy
    modality = policy.get_modality_config()
    print("Current modality config: \n", modality)

    dataset = LeRobotSingleDataset(
        dataset_path=args.dataset_path,
        modality_configs=modality,
        video_backend=args.video_backend,
        video_backend_kwargs=None,
        transforms=None,  # transforms are applied by the policy, not the dataset
        embodiment_tag=args.embodiment_tag,
    )
    print(len(dataset))

    # Diagnostics: the first dataset sample and the first step of trajectory 0.
    _print_sample(dataset[0])
    _print_sample(dataset.get_step_data(0, 0))

    print("Total trajectories:", len(dataset.trajectory_lengths))
    print("All trajectories:", dataset.trajectory_lengths)
    print("Running on all trajs with modality keys:", args.modality_keys)

    all_mse = []
    for traj_id in range(args.trajs):
        print("Running trajectory:", traj_id)
        mse = calc_mse_for_single_trajectory(
            policy,
            dataset,
            traj_id,
            modality_keys=args.modality_keys,
            steps=args.steps,
            action_horizon=action_horizon,
            plot=args.plot,
            save_plot_path=_plot_path_for_traj(args.save_plot_path, traj_id, args.trajs),
        )
        print("MSE:", mse)
        all_mse.append(mse)

    # np.mean([]) would produce nan plus a RuntimeWarning when no trajectory ran.
    if all_mse:
        print("Average MSE across all trajs:", np.mean(all_mse))
    else:
        print("No trajectories were evaluated; skipping average MSE.")
    print("Done")
if __name__ == "__main__":
    # tyro turns ArgsConfig's fields (and their docstrings) into CLI flags and help text.
    main(tyro.cli(ArgsConfig))