"""Print a quick structural summary of a ``.pt`` dataset file.

Usage: ``python <script> [PATH]`` -- PATH defaults to ``pt_file``.
"""
import os
import sys
import traceback
from pathlib import Path

import torch

try:
    # Presumably imported because these checkpoints may contain TensorDict
    # objects. Kept optional so plain-dict checkpoints can be inspected even
    # when tensordict is not installed.
    import tensordict  # noqa: F401
except ImportError:
    tensordict = None

pt_file = "/storage/ice-shared/ae8803che/hxue/data/dataset/dreamer4/mixed-large/acrobot-swingup.pt"


def inspect_checkpoint(path):
    """Load ``path`` onto the CPU and print a summary of its top-level contents.

    Args:
        path: Path (str or ``Path``) of the file to load with ``torch.load``.

    Returns:
        The loaded object, so callers can keep inspecting it.

    Raises:
        Whatever ``torch.load`` raises (e.g. ``FileNotFoundError``, unpickling
        errors).
    """
    print(f"Loading {path}...")
    # SECURITY: weights_only=False uses the full pickle loader, which can run
    # arbitrary code embedded in the file. Only use this on trusted files.
    data = torch.load(path, map_location="cpu", weights_only=False)

    if not hasattr(data, "keys"):
        print(f"Loaded object of type {type(data).__name__} has no keys")
        return data

    # Query the keys view instead of `key in data`: TensorDict objects reject
    # `in` on the object itself, while dict and TensorDict key views both
    # support it.
    keys = data.keys()
    print("Keys:", keys)

    if "obs" in keys:
        obs = data["obs"]
        print("Obs keys:", obs.keys() if hasattr(obs, "keys") else "Not a dict")
        if hasattr(obs, "shape"):
            print("Obs shape:", obs.shape)

    # Report each expected entry independently, so one missing key does not
    # hide the others.
    for key, label in (("action", "Action"), ("episode", "Episode")):
        if key not in keys:
            print(f"{label}: missing")
            continue
        value = data[key]
        if hasattr(value, "shape"):
            print(f"{label} shape:", value.shape)
        else:
            print(f"{label}: no shape (type {type(value).__name__})")
    return data


def main(argv=None):
    """Run the inspection and return a process exit code (0 ok, 1 on error).

    Args:
        argv: Argument list without the program name. Defaults to
            ``sys.argv[1:]``. The first element, if present, overrides
            ``pt_file``.
    """
    args = sys.argv[1:] if argv is None else argv
    path = Path(args[0]) if args else Path(pt_file)
    try:
        inspect_checkpoint(path)
    except Exception as e:  # top-level boundary of a diagnostic script
        print(f"Error: {e}")
        traceback.print_exc()
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())