# Blackjack-CNN / cnn_eval.py
# DaertML's picture
# Upload 2 files
# bd918e6 verified
import torch
import torch.nn as nn
import gymnasium as gym
import numpy as np
# --- 1. Re-defining the exact architecture from your training script ---
class BlackjackCNN(nn.Module):
    """Small convolutional network mapping a 1x3x3 state grid to two action scores.

    The attribute names (``conv``, ``fc``) and the position of every layer
    inside its ``nn.Sequential`` determine the state_dict keys
    (``conv.0``, ``conv.2``, ``fc.1``, ``fc.3``), so they must stay as they
    are for saved weights to keep loading.
    """

    def __init__(self):
        super().__init__()

        # Each 2x2 convolution with padding=1 and stride=1 grows the spatial
        # size by one: a 3x3 input becomes 4x4, then 5x5.
        first_conv = nn.Conv2d(
            in_channels=1, out_channels=16, kernel_size=2, stride=1, padding=1
        )
        second_conv = nn.Conv2d(
            in_channels=16, out_channels=32, kernel_size=2, stride=1, padding=1
        )
        self.conv = nn.Sequential(first_conv, nn.ReLU(), second_conv, nn.ReLU())

        # For a 3x3 input the flattened feature size is 32 channels * 5 * 5 = 800.
        hidden = nn.Linear(in_features=800, out_features=64)
        head = nn.Linear(in_features=64, out_features=2)
        self.fc = nn.Sequential(nn.Flatten(), hidden, nn.ReLU(), head)

    def forward(self, x):
        """Run the convolutional stack, then the fully connected head."""
        features = self.conv(x)
        return self.fc(features)
def preprocess_state(state):
    """
    Encode a Blackjack observation as a float32 tensor of shape (1, 1, 3, 3).

    State: (Player Sum, Dealer Card, Useable Ace)
    The three normalized values sit on the main diagonal of a 3x3 grid
    (Player/31, Dealer/10, Ace as 1.0 or 0.0); every other cell is zero.
    """
    diagonal = [
        state[0] / 31.0,
        state[1] / 10.0,
        1.0 if state[2] else 0.0,
    ]
    # np.diag builds the zero-filled 3x3 float64 grid with these values on
    # the diagonal; .float() then converts it to float32.
    grid = np.diag(diagonal)
    return torch.from_numpy(grid).float().reshape(1, 1, 3, 3)
def test_cnn(path="blackjack_cnn.pth", num_rounds=1000):
    """Evaluate saved BlackjackCNN weights by playing greedy rounds of Blackjack-v1.

    For every step the observation is encoded with ``preprocess_state`` and
    the action with the highest network output is taken (1 is reported as
    HIT, anything else as STICK). The first five rounds are logged step by
    step, and a win/draw/loss summary is printed at the end.

    Args:
        path: Checkpoint file holding the model's state_dict.
        num_rounds: Number of rounds to play. Zero or a negative value plays
            no rounds and prints an all-zero summary.

    Returns:
        None. If the weights cannot be loaded, an error is printed and the
        function returns without playing.
    """
    env = gym.make('Blackjack-v1')
    try:
        model = BlackjackCNN()
        # Load the weights
        try:
            # map_location="cpu" lets a checkpoint saved from a GPU load on a
            # CPU-only machine; the model and its inputs live on the CPU here.
            # NOTE(security): torch.load unpickles the file. Only load
            # checkpoints from a trusted source, or pass weights_only=True on
            # PyTorch versions that support it.
            model.load_state_dict(torch.load(path, map_location="cpu"))
            model.eval()
            print(f"Successfully loaded: {path}")
        except Exception as e:
            print(f"Error loading model: {e}")
            return

        wins = 0
        draws = 0
        losses = 0
        print(f"\nEvaluating CNN for {num_rounds} rounds...")
        for i in range(num_rounds):
            obs, _ = env.reset()
            done = False
            # Log the first 5 rounds to see what's happening
            verbose = i < 5
            if verbose:
                print(f"\nRound {i+1} Start: Player={obs[0]}, Dealer={obs[1]}, Ace={obs[2]}")
            while not done:
                state_img = preprocess_state(obs)
                with torch.no_grad():
                    q_values = model(state_img)
                action = q_values.argmax().item()
                action_name = "HIT" if action == 1 else "STICK"
                obs, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                if verbose:
                    print(f" -> Action: {action_name} | Next State: {obs[0]} | Reward: {reward}")
            # The reward of the final step decides the outcome of the round.
            if reward > 0:
                wins += 1
            elif reward == 0:
                draws += 1
            else:
                losses += 1
    finally:
        # Release the environment on every exit path, including load failure.
        env.close()

    # Guard the percentages against num_rounds <= 0 (all counts are then 0).
    total = max(num_rounds, 1)
    print("-" * 30)
    print("RESULTS FOR CNN ALONE:")
    print(f"Wins: {wins} ({wins/total:.1%})")
    print(f"Draws: {draws} ({draws/total:.1%})")
    print(f"Losses: {losses} ({losses/total:.1%})")
    print("-" * 30)
# Run the evaluation with its default arguments when executed as a script.
if __name__ == "__main__":
    test_cnn()