File size: 6,109 Bytes
06c11b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import argparse
import h5py
import sys
import numpy as np

DEFAULT_PATH = "/data/hongzefu/data_0225/record_dataset_VideoUnmaskSwap.h5"

def _is_equal(a, b):
    if a is None and b is None:
        return True
    if a is None or b is None:
        return False
    if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
        return np.array_equal(a, b)
    try:
        return bool(np.all(a == b))
    except:
        return False

def inspect_actions(filepath, target_timestep=None, window=10):
    print(f"Inspecting HDF5 file: {filepath}")
    
    try:
        with h5py.File(filepath, 'r') as f:
            # Sort episodes naturally (e.g., episode_0, episode_1, etc.)
            episodes = [k for k in f.keys() if k.startswith('episode_')]
            episodes.sort(key=lambda x: int(x.split('_')[1]))
            
            for ep_name in episodes:
                ep_group = f[ep_name]
                print(f"--- {ep_name} ---")
                
                # Sort timesteps naturally
                timesteps = [k for k in ep_group.keys() if k.startswith('timestep_')]
                timesteps.sort(key=lambda x: int(x.split('_')[1]))
                
                last_choice_action = None
                last_is_subgoal_boundary = None
                last_waypoint_action = None
                last_is_video_demo = None
                skip_count = 0
                
                for ts_name in timesteps:
                    ts_group = ep_group[ts_name]
                    
                    # Try to find the requested fields
                    choice_action = None
                    is_subgoal_boundary = None
                    is_video_demo = None
                    waypoint_action = None
                    
                    # Action group typically contains choice_action and waypoint_action
                    if 'action' in ts_group:
                        action_group = ts_group['action']
                        if 'choice_action' in action_group:
                            choice_action = np.array(action_group['choice_action'])
                        if 'waypoint_action' in action_group:
                            waypoint_action = np.array(action_group['waypoint_action'])
                            
                    # info group typically contains is_subgoal_boundary and is_video_demo
                    if 'info' in ts_group:
                        info_group = ts_group['info']
                        if 'is_subgoal_boundary' in info_group:
                            is_subgoal_boundary = np.array(info_group['is_subgoal_boundary'])
                        if 'is_video_demo' in info_group:
                            is_video_demo = np.array(info_group['is_video_demo'])
                            
                    # For safety, check if they are directly under timestep or elsewhere
                    if choice_action is None and 'choice_action' in ts_group:
                        choice_action = np.array(ts_group['choice_action'])
                    if is_subgoal_boundary is None and 'is_subgoal_boundary' in ts_group:
                        is_subgoal_boundary = np.array(ts_group['is_subgoal_boundary'])
                    if is_video_demo is None and 'is_video_demo' in ts_group:
                        is_video_demo = np.array(ts_group['is_video_demo'])
                    if waypoint_action is None and 'waypoint_action' in ts_group:
                        waypoint_action = np.array(ts_group['waypoint_action'])

                    should_skip = False
                    if last_choice_action is not None:
                        same_choice = _is_equal(choice_action, last_choice_action)
                        same_boundary = _is_equal(is_subgoal_boundary, last_is_subgoal_boundary)
                        same_vd = _is_equal(is_video_demo, last_is_video_demo)
                        same_wp = _is_equal(waypoint_action, last_waypoint_action)
                        
                        should_skip = same_choice and same_boundary and same_vd and same_wp
                        
                        if target_timestep is not None:
                            ts_idx = int(ts_name.split('_')[1])
                            if abs(ts_idx - target_timestep) <= window:
                                should_skip = False
                        
                    if should_skip:
                        skip_count += 1
                        continue

                    if skip_count > 0:
                        print(f"  ... ({skip_count} identical timesteps skipped) ...")
                        skip_count = 0

                    print(f"  {ts_name}:")
                    print(f"    choice_action:   {choice_action}")
                    print(f"    is_subgoal_boundary: {is_subgoal_boundary}")
                    print(f"    is_video_demo:   {is_video_demo}")
                    print(f"    waypoint_action: {waypoint_action}")

                    last_choice_action = choice_action
                    last_is_subgoal_boundary = is_subgoal_boundary
                    last_is_video_demo = is_video_demo
                    last_waypoint_action = waypoint_action

                if skip_count > 0:
                    print(f"  ... ({skip_count} identical timesteps skipped) ...")
                    skip_count = 0

    except Exception as e:
        print(f"Error reading HDF5 file: {e}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Inspect HDF5 dataset actions.")
    parser.add_argument("filepath", type=str, nargs="?", default=DEFAULT_PATH,
                        help="Path to the HDF5 file.")
    parser.add_argument("-t", "--timestep", type=int, default=None,
                        help="Specific timestep to not omit even if identical.")
    parser.add_argument("-w", "--window", type=int, default=10,
                        help="Window around specified timestep to not omit (default: 10).")
    args = parser.parse_args()

    inspect_actions(args.filepath, target_timestep=args.timestep, window=args.window)