CodyAI-Model / train.py
NeuralWolf's picture
Added all the files my model will need
9d9bb0d verified
Raw
History Blame Contribute Delete
3.61 kB
import numpy as np
import matplotlib.pyplot as plt
from snake_env import SnakeEnv
from Cod_agent import DQNAgent
def plot_results(scores, best_scores, avg_scores, epsilons):
    """Plot the training curves, save them to a PNG, and show the figure.

    Two stacked panels are drawn: the upper one overlays the three score
    series, the lower one shows epsilon. All series are plotted against
    the 1-based game index derived from ``len(scores)``, so the four
    sequences are expected to have the same length.

    Args:
        scores: Score of each game, in play order.
        best_scores: Series drawn with the "Best score" label.
        avg_scores: Series drawn with the "Avg last 50 games" label.
        epsilons: Series drawn in the lower panel with the "Epsilon" label.

    Side effects:
        Writes ``training_progress.png`` to the current directory, prints a
        confirmation line, and calls ``plt.show()``.
    """
    figure, (score_ax, eps_ax) = plt.subplots(2, 1, figsize=(10, 7))
    figure.suptitle("Snake AI — Training Results", fontsize=14, fontweight="bold")

    episode_numbers = list(range(1, len(scores) + 1))
    panel_background = "#0f1923"

    # Upper panel: one entry per line, drawn in this order so the legend
    # lists them the same way.
    score_series = (
        (scores, {"color": "#888888", "alpha": 0.5,
                  "linewidth": 1, "label": "Score per game"}),
        (best_scores, {"color": "#2ed573",
                       "linewidth": 2, "label": "Best score"}),
        (avg_scores, {"color": "#ffa502", "linewidth": 2,
                      "linestyle": "--", "label": "Avg last 50 games"}),
    )
    score_ax.set_facecolor(panel_background)
    for values, line_style in score_series:
        score_ax.plot(episode_numbers, values, **line_style)
    score_ax.set_ylabel("Score")
    score_ax.legend(loc="upper left")
    score_ax.grid(True, alpha=0.2)

    # Lower panel: epsilon per game.
    eps_ax.set_facecolor(panel_background)
    eps_ax.plot(episode_numbers, epsilons,
                color="#7f77dd", linewidth=2, label="Epsilon")
    eps_ax.set_ylabel("Epsilon")
    eps_ax.set_xlabel("Game")
    eps_ax.legend(loc="upper right")
    eps_ax.grid(True, alpha=0.2)

    plt.tight_layout()
    plt.savefig("training_progress.png", dpi=150, bbox_inches="tight")
    print("Graph saved → training_progress.png")
    plt.show()
def train(num_games=1000, render=False, model_path="Cod_brain.npz"):
    """Train a DQN agent on the snake environment and plot the results.

    Each game runs until the environment reports ``done``. After every step
    the transition is stored with ``agent.remember`` and ``agent.learn`` is
    called. Whenever a game beats the best score so far, the agent is saved
    to ``model_path``. When all games are finished the collected statistics
    are passed to ``plot_results``.

    Args:
        num_games: Number of games to play.
        render: Forwarded to ``SnakeEnv(render=...)``. Defaults to ``False``
            so that training does not open a game window.
        model_path: File the best-scoring agent is saved to. The default
            is the same file name that ``watch()`` loads.
    """
    env = SnakeEnv(render=render)
    agent = DQNAgent()

    best_score = 0
    all_scores = []
    best_scores = []
    avg_scores = []
    epsilons = []

    print("Training started! Graph will show when done.\n")
    print(f"{'Game':<8} {'Score':<8} {'Best':<8} {'Avg(50)':<10} {'Epsilon'}")
    print("-" * 55)

    for game in range(1, num_games + 1):
        state = env.reset()

        # Play one full game, learning after every step.
        done = False
        while not done:
            action = agent.act(state)
            next_state, reward, done = env.step(action)
            agent.remember(state, action, reward, next_state, done)
            agent.learn()
            state = next_state

        # --- Record stats ---
        all_scores.append(env.score)
        if env.score > best_score:
            # Only a strictly better score overwrites the saved model.
            best_score = env.score
            agent.save(model_path)
        best_scores.append(best_score)

        # Rolling mean over (at most) the 50 most recent games.
        avg = float(np.mean(all_scores[-50:]))
        avg_scores.append(avg)
        epsilons.append(agent.epsilon)

        if game % 10 == 0:
            print(f"{game:<8} {env.score:<8} {best_score:<8} "
                  f"{avg:<10.2f} {agent.epsilon:.4f}")

    print(f"\nTraining complete! Best score: {best_score}")
    plot_results(all_scores, best_scores, avg_scores, epsilons)
def watch():
    """Load the trained model and watch it play.

    Creates a rendering environment, loads the agent from
    ``Cod_brain.npz`` and then plays game after game, printing the score of
    each finished game. If the model file does not exist, a hint is printed
    and the function returns without playing. The game loop itself has no
    exit condition; it runs until the process is stopped or a call raises.
    """
    env = SnakeEnv(render=True)
    agent = DQNAgent()

    try:
        agent.load("Cod_brain.npz")
    except FileNotFoundError:
        print("No saved model found! Train first: python train.py")
        return

    # Set epsilon only after loading: if load() restores a saved epsilon
    # (not visible from this file), assigning it earlier would be
    # overwritten and the agent would still take exploratory actions.
    agent.epsilon = 0.0

    print("Watching AI play... (close window to stop)\n")

    game = 0
    while True:
        state = env.reset()
        game += 1

        done = False
        while not done:
            action = agent.act(state)
            state, _, done = env.step(action)

        print(f"Game {game} | Score: {env.score}")
# ================================================================
if __name__ == "__main__":
    import sys

    # The first command-line argument selects the mode. Only the exact
    # word "watch" starts playback; anything else (or no argument) trains.
    cli_args = sys.argv[1:]
    selected_mode = cli_args[0] if cli_args else "train"
    entry_point = watch if selected_mode == "watch" else train
    entry_point()