lsnu commited on
Commit
9ccea90
·
verified ·
1 Parent(s): 1d349b1

Add TWIN preprocessing and norm-stats helper scripts

Browse files
openpi/scripts/compute_norm_stats_repo.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Compute normalization statistics for a config with repo/loader overrides.
2
+
3
+ This exists so multiprocessing workers can respawn from a real on-disk module
4
+ instead of failing when the parent process was launched from stdin.
5
+ """
6
+
7
+ import dataclasses
8
+
9
+ import numpy as np
10
+ import tqdm
11
+ import tyro
12
+
13
+ import openpi.shared.normalize as normalize
14
+ import openpi.training.config as config_lib
15
+ import compute_norm_stats as base_compute
16
+
17
+
18
+ def main(
19
+ config_name: str,
20
+ repo_id: str,
21
+ batch_size: int = 64,
22
+ num_workers: int = 12,
23
+ assets_base_dir: str = "./assets",
24
+ max_frames: int | None = None,
25
+ ):
26
+ config = dataclasses.replace(
27
+ config_lib.get_config(config_name),
28
+ batch_size=batch_size,
29
+ num_workers=num_workers,
30
+ assets_base_dir=assets_base_dir,
31
+ )
32
+ data_factory = dataclasses.replace(config.data, repo_id=repo_id)
33
+ data_config = data_factory.create(config.assets_dirs, config.model)
34
+
35
+ if data_config.rlds_data_dir is not None:
36
+ data_loader, num_batches = base_compute.create_rlds_dataloader(
37
+ data_config,
38
+ config.model.action_horizon,
39
+ config.batch_size,
40
+ max_frames,
41
+ )
42
+ else:
43
+ data_loader, num_batches = base_compute.create_torch_dataloader(
44
+ data_config,
45
+ config.model.action_horizon,
46
+ config.batch_size,
47
+ config.model,
48
+ config.num_workers,
49
+ max_frames,
50
+ )
51
+
52
+ keys = ["state", "actions"]
53
+ stats = {key: normalize.RunningStats() for key in keys}
54
+
55
+ for batch in tqdm.tqdm(data_loader, total=num_batches, desc=f"{config_name} :: {repo_id}"):
56
+ for key in keys:
57
+ stats[key].update(np.asarray(batch[key]))
58
+
59
+ output_path = config.assets_dirs / repo_id
60
+ print(f"Writing stats to: {output_path}")
61
+ normalize.save(output_path, {key: value.get_statistics() for key, value in stats.items()})
62
+
63
+
64
+ if __name__ == "__main__":
65
+ tyro.cli(main)