DaertML committed on
Commit
bd918e6
·
verified ·
1 Parent(s): 1e03148

Upload 2 files

Browse files
Files changed (2) hide show
  1. cnn_eval.py +93 -0
  2. cnn_train.py +112 -0
cnn_eval.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import gymnasium as gym
4
+ import numpy as np
5
+
6
+ # --- 1. Re-defining the exact architecture from your training script ---
7
class BlackjackCNN(nn.Module):
    """Small convolutional network mapping a 1x3x3 state grid to two action values."""

    def __init__(self):
        super().__init__()
        # Each 2x2 convolution with padding=1 grows the spatial size by one:
        # 3x3 -> 4x4 -> 5x5, so the flattened feature count is 32 * 5 * 5 = 800.
        feature_layers = [
            nn.Conv2d(1, 16, kernel_size=2, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=2, stride=1, padding=1),
            nn.ReLU(),
        ]
        head_layers = [
            nn.Flatten(),
            nn.Linear(800, 64),
            nn.ReLU(),
            nn.Linear(64, 2),
        ]
        # Attribute names and layer positions are kept so the state_dict keys
        # ("conv.0", "conv.2", "fc.1", "fc.3") match previously saved weights.
        self.conv = nn.Sequential(*feature_layers)
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the two action values for a batch ``x`` of shape (N, 1, 3, 3)."""
        return self.fc(self.conv(x))
26
+
27
def preprocess_state(state):
    """Encode a Blackjack observation as a 1x1x3x3 float tensor.

    ``state`` is indexed as (player sum, dealer card, usable ace). The three
    features sit on the diagonal of an otherwise zero 3x3 grid, scaled as
    player sum / 31, dealer card / 10, and 1.0 or 0.0 for the ace flag.
    """
    player_feature = state[0] / 31.0
    dealer_feature = state[1] / 10.0
    ace_feature = 1.0 if state[2] else 0.0

    # np.diag builds the 3x3 matrix with these values on the main diagonal.
    grid = np.diag(np.array([player_feature, dealer_feature, ace_feature]))
    return torch.FloatTensor(grid).view(1, 1, 3, 3)
37
+
38
def test_cnn(path="blackjack_cnn.pth", num_rounds=1000):
    """Play Blackjack greedily with saved weights and print the win/draw/loss tally.

    Args:
        path: File holding a ``BlackjackCNN`` state dict.
        num_rounds: Number of rounds to play; must be at least 1.

    Raises:
        ValueError: If ``num_rounds`` is less than 1 (the percentages printed
            at the end would otherwise divide by zero).

    If the weights cannot be loaded, the error is printed and the function
    returns without playing any round.
    """
    if num_rounds < 1:
        raise ValueError(f"num_rounds must be at least 1, got {num_rounds}")

    model = BlackjackCNN()

    # Load the weights
    try:
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # machine. weights_only=True restricts unpickling to tensors and
        # plain containers, so an untrusted file cannot run arbitrary code.
        state_dict = torch.load(path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
    except Exception as e:
        # Entry-point boundary: report the problem and stop instead of crashing.
        print(f"Error loading model: {e}")
        return
    model.eval()
    print(f"Successfully loaded: {path}")

    wins = 0
    draws = 0
    losses = 0

    print(f"\nEvaluating CNN for {num_rounds} rounds...")

    # The environment is created only after the weights loaded, and is always
    # closed, so neither a failed load nor an error mid-game leaks it.
    env = gym.make('Blackjack-v1')
    try:
        for i in range(num_rounds):
            obs, _ = env.reset()
            done = False
            verbose = i < 5  # Log the first 5 rounds to see what's happening

            if verbose:
                print(f"\nRound {i+1} Start: Player={obs[0]}, Dealer={obs[1]}, Ace={obs[2]}")

            while not done:
                state_img = preprocess_state(obs)
                with torch.no_grad():
                    q_values = model(state_img)
                action = q_values.argmax().item()

                action_name = "HIT" if action == 1 else "STICK"
                obs, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                if verbose:
                    print(f"  -> Action: {action_name} | Next State: {obs[0]} | Reward: {reward}")

            # The reward of the final step decides the outcome of the round.
            if reward > 0:
                wins += 1
            elif reward == 0:
                draws += 1
            else:
                losses += 1
    finally:
        env.close()

    print("-" * 30)
    print("RESULTS FOR CNN ALONE:")
    print(f"Wins:   {wins} ({wins/num_rounds:.1%})")
    print(f"Draws:  {draws} ({draws/num_rounds:.1%})")
    print(f"Losses: {losses} ({losses/num_rounds:.1%})")
    print("-" * 30)


if __name__ == "__main__":
    test_cnn()
cnn_train.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ import numpy as np
6
+ import random
7
+ from collections import deque
8
+
9
+ # --- Hyperparameters ---
10
+ LEARNING_RATE = 0.001  # passed as `lr` to the Adam optimizer
11
+ GAMMA = 0.95  # multiplies the next-state max Q-value in the training target
12
+ EPSILON_START = 1.0  # initial probability of sampling a random action
13
+ EPSILON_END = 0.01  # lower bound that epsilon is clamped to
14
+ EPSILON_DECAY = 0.995  # factor epsilon is multiplied by each time it decays
15
+ MEMORY_SIZE = 10000  # maxlen of the replay deque
16
+ BATCH_SIZE = 64  # number of transitions sampled from memory per optimization step
17
+ EPISODES = 1000  # number of training episodes
18
+ MODEL_PATH = "blackjack_cnn.pth" # Local filename the trained state dict is saved to
19
+
20
class BlackjackCNN(nn.Module):
    """Small convolutional network mapping a 1x3x3 state grid to two action values."""

    def __init__(self):
        super().__init__()
        # Each 2x2 convolution with padding=1 grows the spatial size by one:
        # 3x3 -> 4x4 -> 5x5, so the flattened feature count is 32 * 5 * 5 = 800.
        feature_layers = [
            nn.Conv2d(1, 16, kernel_size=2, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=2, stride=1, padding=1),
            nn.ReLU(),
        ]
        head_layers = [
            nn.Flatten(),
            nn.Linear(800, 64),
            nn.ReLU(),
            nn.Linear(64, 2),
        ]
        # Attribute names and layer positions are kept so the saved state_dict
        # keeps the keys "conv.0", "conv.2", "fc.1" and "fc.3".
        self.conv = nn.Sequential(*feature_layers)
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the two action values for a batch ``x`` of shape (N, 1, 3, 3)."""
        return self.fc(self.conv(x))
39
+
40
def preprocess_state(state):
    """Encode an observation as a 1x1x3x3 float tensor.

    The first three entries of ``state`` are scaled (``state[0] / 31``,
    ``state[1] / 10``, and 1.0 or 0.0 for the truthiness of ``state[2]``)
    and placed on the diagonal of an otherwise zero 3x3 grid.
    """
    diagonal = np.array(
        [
            state[0] / 31.0,
            state[1] / 10.0,
            1.0 if state[2] else 0.0,
        ]
    )
    # np.diag builds the 3x3 matrix with these values on the main diagonal.
    grid = np.diag(diagonal)
    return torch.FloatTensor(grid).view(1, 1, 3, 3)
46
+
47
# --- Training Loop ---
# Trains `policy_net` on Blackjack-v1 from a replay memory of
# (state, action, reward, next_state, done) transitions.
env = gym.make('Blackjack-v1')
policy_net = BlackjackCNN()
optimizer = optim.Adam(policy_net.parameters(), lr=LEARNING_RATE)
# The criterion holds no state, so it is built once instead of on every step.
loss_fn = nn.MSELoss()
memory = deque(maxlen=MEMORY_SIZE)
epsilon = EPSILON_START

print(f"Starting training for {EPISODES} episodes...")

for episode in range(EPISODES):
    obs, info = env.reset()
    state_img = preprocess_state(obs)
    done = False

    while not done:
        # Epsilon-greedy: random action with probability epsilon, else argmax Q.
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            with torch.no_grad():
                action = policy_net(state_img).argmax().item()

        next_obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        next_state_img = preprocess_state(next_obs)

        memory.append((state_img, action, reward, next_state_img, done))
        state_img = next_state_img

        # Start learning as soon as one full batch is available.
        if len(memory) >= BATCH_SIZE:
            batch = random.sample(memory, BATCH_SIZE)
            states, actions, rewards, next_states, dones = zip(*batch)

            states = torch.cat(states)
            # gather() needs an int64 index; sampled actions may be NumPy ints
            # whose default width is platform dependent.
            actions = torch.tensor(actions, dtype=torch.int64).unsqueeze(1)
            rewards = torch.tensor(rewards, dtype=torch.float32)
            next_states = torch.cat(next_states)
            dones = torch.tensor(dones, dtype=torch.float32)

            # squeeze(1) removes only the gathered column, so the result stays
            # 1-D even when BATCH_SIZE is 1 and always matches target_q.
            current_q = policy_net(states).gather(1, actions).squeeze(1)
            # The target is treated as a constant, so no graph is built for it.
            with torch.no_grad():
                next_q = policy_net(next_states).max(1)[0]
            # (1 - dones) zeroes the bootstrap term for terminal transitions.
            target_q = rewards + (GAMMA * next_q * (1 - dones))

            loss = loss_fn(current_q, target_q)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Decay exploration once per episode, never dropping below EPSILON_END.
    epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)

    if (episode + 1) % 100 == 0:
        print(f"Episode {episode + 1} | Epsilon: {epsilon:.2f}")
98
+
99
# --- Save the Model ---
# Only the parameters (state dict) are written, not the module object.
trained_weights = policy_net.state_dict()
torch.save(trained_weights, MODEL_PATH)
print(f"\nModel saved locally to {MODEL_PATH}")

# --- Quick Test ---
# Shows the greedy first action the in-memory network picks for 5 fresh deals.
print("\nTesting saved model for 5 rounds:")
policy_net.eval()  # Set to evaluation mode
with torch.no_grad():
    for i in range(5):
        obs, _ = env.reset()
        state_img = preprocess_state(obs)
        q_values = policy_net(state_img)
        action = q_values.argmax().item()
        if action == 1:
            action_name = "HIT"
        else:
            action_name = "STICK"
        print(f"Round {i+1}: Hand={obs[0]}, Dealer={obs[1]}, Action={action_name}")