trioskosmos commited on
Commit
cd9582d
·
verified ·
1 Parent(s): 764819e

Upload ai/utils/check_action_dist.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/utils/check_action_dist.py +53 -0
ai/utils/check_action_dist.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import glob
3
+
4
+ import numpy as np
5
+
6
+
7
+ def analyze_distribution(pattern):
8
+ files = glob.glob(pattern)
9
+ if not files:
10
+ print(f"No files found for {pattern}")
11
+ return
12
+
13
+ total_samples = 0
14
+ action_counts = {}
15
+
16
+ print(f"Analyzing {len(files)} files...")
17
+ for f in files[:20]: # Check first 20 chunks for quick insight
18
+ data = np.load(f)
19
+ policies = data["policies"]
20
+
21
+ # policies is [N, 2000]
22
+ # Get argmax for each sample
23
+ best_actions = np.argmax(policies, axis=1)
24
+
25
+ total_samples += len(best_actions)
26
+ unique, counts = np.unique(best_actions, return_counts=True)
27
+ for a, c in zip(unique, counts):
28
+ action_counts[a] = action_counts.get(a, 0) + c
29
+
30
+ print(f"\nTotal Samples checked: {total_samples}")
31
+ sorted_actions = sorted(action_counts.items(), key=lambda x: x[1], reverse=True)
32
+
33
+ print("\nTop 10 Actions:")
34
+ for action, count in sorted_actions[:10]:
35
+ percentage = (count / total_samples) * 100
36
+ print(f"Action {action:4}: {count:6} samples ({percentage:5.1f}%)")
37
+
38
+ # Specifically check Action 0 (End Phase)
39
+ pass_count = action_counts.get(0, 0)
40
+ pass_perc = (pass_count / total_samples) * 100
41
+ print(f"\nAction 0 (Pass/End): {pass_perc:.1f}%")
42
+
43
+ # Check "tactical" actions (non-zero)
44
+ tactical_count = total_samples - pass_count
45
+ tactical_perc = (tactical_count / total_samples) * 100
46
+ print(f"Tactical Actions: {tactical_perc:.1f}%")
47
+
48
+
49
+ if __name__ == "__main__":
50
+ parser = argparse.ArgumentParser()
51
+ parser.add_argument("--data", type=str, default="ai/data/self_play_0_chunk_*.npz")
52
+ args = parser.parse_args()
53
+ analyze_distribution(args.data)