trioskosmos commited on
Commit
0f9c2a1
·
verified ·
1 Parent(s): cd9582d

Upload ai/utils/debug_agent.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/utils/debug_agent.py +54 -0
ai/utils/debug_agent.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ # Add project root to path
8
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
9
+
10
+ from ai.models.training_config import POLICY_SIZE
11
+ from ai.training.train import AlphaNet
12
+
13
+
14
+ def debug_model(model_path, data_path):
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ checkpoint = torch.load(model_path, map_location=device)
17
+
18
+ if isinstance(checkpoint, dict) and "model_state" in checkpoint:
19
+ state_dict = checkpoint["model_state"]
20
+ else:
21
+ state_dict = checkpoint
22
+
23
+ model = AlphaNet(policy_size=POLICY_SIZE).to(device)
24
+ model.load_state_dict(state_dict)
25
+ model.eval()
26
+
27
+ print(f"Loading data from {data_path}...")
28
+ data = np.load(data_path)
29
+ states = data["states"][:5]
30
+ true_policies = data["policies"][:5]
31
+
32
+ for i in range(len(states)):
33
+ state = torch.FloatTensor(states[i]).unsqueeze(0).to(device)
34
+ with torch.no_grad():
35
+ p_logits, v = model(state)
36
+ p_probs = torch.softmax(p_logits, dim=1)
37
+
38
+ print(f"\nSample {i}:")
39
+ print(f"Value prediction: {v.item():.4f}")
40
+
41
+ # Check Top-5 predicted actions
42
+ top_probs, top_actions = torch.topk(p_probs, 5)
43
+ print("Top 5 Predictions:")
44
+ for j in range(5):
45
+ print(f" Action {top_actions[0][j].item()}: {top_probs[0][j].item():.1%}")
46
+
47
+ # Check ground truth Top-1
48
+ gt_action = np.argmax(true_policies[i])
49
+ gt_prob = true_policies[i][gt_action]
50
+ print(f"Ground Truth Action {gt_action} with weight {gt_prob:.1%}")
51
+
52
+
53
+ if __name__ == "__main__":
54
+ debug_model("ai/models/alphanet_best.pt", "ai/data/data_batch_0.npz")