# world_model/wm/dataset/test_load_pt.py
# Uploaded by t1an via huggingface_hub ("Upload folder using huggingface_hub"),
# commit f17ae24 (verified).
import torch
import tensordict
import os
from pathlib import Path
# Default dataset file inspected when no path is given on the command line.
pt_file = "/storage/ice-shared/ae8803che/hxue/data/dataset/dreamer4/mixed-large/acrobot-swingup.pt"


def summarize(data):
    """Yield human-readable lines describing the top-level structure of ``data``.

    Lines are yielded one at a time so a caller printing them still shows
    partial output if inspecting a later entry fails.

    Args:
        data: The object returned by ``torch.load``. It must support
            ``keys()``, ``in`` and ``[]`` lookup (e.g. a ``dict``; the concrete
            type stored in these .pt files is not visible here — presumably a
            dict or ``TensorDict``, confirm against the dataset writer).

    Yields:
        ``"Keys: ..."``, then optional ``"Obs keys: ..."`` / ``"Obs shape: ..."``
        lines when an ``'obs'`` entry exists, then one line each for the
        ``'action'`` and ``'episode'`` entries.
    """
    yield f"Keys: {data.keys()}"
    if "obs" in data:
        obs = data["obs"]
        yield f"Obs keys: {obs.keys() if hasattr(obs, 'keys') else 'Not a dict'}"
        if hasattr(obs, "shape"):
            yield f"Obs shape: {obs.shape}"
    # Report a missing entry explicitly instead of letting a KeyError abort
    # the whole summary.
    for key, label in (("action", "Action"), ("episode", "Episode")):
        if key not in data:
            yield f"{label} shape: <missing key {key!r}>"
        else:
            yield f"{label} shape: {getattr(data[key], 'shape', '<no shape attribute>')}"


def main(argv=None):
    """Load a ``.pt`` dataset file on CPU and print a structural summary.

    Args:
        argv: Command-line arguments without the program name. The first
            element, if present, overrides the default ``pt_file`` path.
            Defaults to ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 if loading or inspection failed.
    """
    import sys
    import traceback

    if argv is None:
        argv = sys.argv[1:]
    path = argv[0] if argv else pt_file
    print(f"Loading {path}...")
    try:
        # NOTE(security): weights_only=False unpickles arbitrary objects, which
        # can execute code embedded in the file. Only use it on trusted data;
        # the default path lives on shared storage.
        data = torch.load(path, map_location="cpu", weights_only=False)
        for line in summarize(data):
            print(line)
    except Exception as e:  # top-level boundary of a diagnostic script
        print(f"Error: {e}", file=sys.stderr)
        traceback.print_exc()
        return 1
    return 0


# Guarded so importing this module (e.g. pytest collecting test_*.py files)
# does not trigger a large torch.load as a side effect.
if __name__ == "__main__":
    raise SystemExit(main())