"""Print a quick structural summary of a saved dreamer4 dataset shard.

Usage:
    python <this_script> [PATH]

With no argument, the acrobot-swingup shard at ``DEFAULT_PT_FILE`` is
inspected.
"""
import os
import sys
import traceback
from pathlib import Path

import torch

try:
    # Not referenced directly. Presumably imported so that TensorDict objects
    # pickled into the shard can be resolved on load -- TODO confirm.
    # Guarded so a missing optional package does not abort before loading.
    import tensordict  # noqa: F401
except ImportError:
    tensordict = None

DEFAULT_PT_FILE = (
    "/storage/ice-shared/ae8803che/hxue/data/dataset/dreamer4/"
    "mixed-large/acrobot-swingup.pt"
)


def describe(data):
    """Return printable summary lines for an object produced by ``torch.load``.

    Args:
        data: The loaded object. It is expected to be mapping-like (a ``dict``
            or anything exposing ``keys()`` and ``in``), but other objects are
            reported rather than rejected.

    Returns:
        A list of strings, one per line of output. Missing ``action`` or
        ``episode`` entries are reported as ``<missing>`` instead of raising.
    """
    if not hasattr(data, "keys"):
        summary = [f"Loaded object is a {type(data).__name__}, not a mapping"]
        top_shape = getattr(data, "shape", None)
        if top_shape is not None:
            summary.append(f"Shape: {top_shape}")
        return summary

    summary = [f"Keys: {data.keys()}"]

    if "obs" in data:
        obs = data["obs"]
        obs_keys = obs.keys() if hasattr(obs, "keys") else "Not a dict"
        summary.append(f"Obs keys: {obs_keys}")
        if hasattr(obs, "shape"):
            summary.append(f"Obs shape: {obs.shape}")

    # Report each expected entry independently, so one absent key does not hide
    # the others (previously a KeyError aborted the whole summary).
    for key, label in (("action", "Action"), ("episode", "Episode")):
        if key not in data:
            summary.append(f"{label}: <missing>")
            continue
        value = data[key]
        shape = getattr(value, "shape", None)
        if shape is None:
            summary.append(f"{label} shape: <none; {type(value).__name__}>")
        else:
            summary.append(f"{label} shape: {shape}")
    return summary


def main(argv=None):
    """Load a ``.pt`` shard and print its summary.

    Args:
        argv: Command-line arguments excluding the program name. Defaults to
            ``sys.argv[1:]``. The first argument, if present, overrides
            ``DEFAULT_PT_FILE``.

    Returns:
        Process exit status: 0 on success, 1 if the file could not be loaded.
    """
    args = sys.argv[1:] if argv is None else argv
    pt_file = args[0] if args else DEFAULT_PT_FILE
    print(f"Loading {pt_file}...")
    try:
        # SECURITY: weights_only=False runs the full pickle machinery, which
        # can execute arbitrary code. Only load shards from a trusted source.
        data = torch.load(pt_file, map_location="cpu", weights_only=False)
    except Exception as e:  # top-level boundary of a CLI script
        print(f"Error: {e}")
        traceback.print_exc()
        return 1

    for line in describe(data):
        print(line)
    return 0


if __name__ == "__main__":
    sys.exit(main())